00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "kernel/SalzbergWordKernel.h"
00014 #include "features/Features.h"
00015 #include "features/StringFeatures.h"
00016 #include "features/Labels.h"
00017 #include "classifier/PluginEstimate.h"
00018
00019 CSalzbergWordKernel::CSalzbergWordKernel(INT size, CPluginEstimate* pie)
00020 : CStringKernel<WORD>(size), estimate(pie), mean(NULL), variance(NULL),
00021 sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL),
00022 ld_mean_lhs(NULL), ld_mean_rhs(NULL),
00023 num_params(0), num_symbols(0), sum_m2_s2(0), pos_prior(0.5),
00024 neg_prior(0.5), initialized(false)
00025 {
00026 }
00027
00028 CSalzbergWordKernel::CSalzbergWordKernel(
00029 CStringFeatures<WORD>* l, CStringFeatures<WORD>* r, CPluginEstimate* pie)
00030 : CStringKernel<WORD>(10),estimate(pie), mean(NULL), variance(NULL),
00031 sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL),
00032 ld_mean_lhs(NULL), ld_mean_rhs(NULL),
00033 num_params(0), num_symbols(0), sum_m2_s2(0), pos_prior(0.5),
00034 neg_prior(0.5), initialized(false)
00035 {
00036 init(l, r);
00037 }
00038
00039 CSalzbergWordKernel::~CSalzbergWordKernel()
00040 {
00041 cleanup();
00042 }
00043
00044 bool CSalzbergWordKernel::init(CFeatures* p_l, CFeatures* p_r)
00045 {
00046 bool status=CStringKernel<WORD>::init(p_l,p_r);
00047 CStringFeatures<WORD>* l=(CStringFeatures<WORD>*) p_l;
00048 ASSERT(l);
00049 CStringFeatures<WORD>* r=(CStringFeatures<WORD>*) p_r;
00050 ASSERT(r);
00051
00052 INT i;
00053 initialized=false;
00054
00055 if (sqrtdiag_lhs!=sqrtdiag_rhs)
00056 delete[] sqrtdiag_rhs;
00057 sqrtdiag_rhs=NULL;
00058 delete[] sqrtdiag_lhs;
00059 sqrtdiag_lhs=NULL;
00060 if (ld_mean_lhs!=ld_mean_rhs)
00061 delete[] ld_mean_rhs;
00062 ld_mean_rhs=NULL;
00063 delete[] ld_mean_lhs;
00064 ld_mean_lhs=NULL;
00065
00066 sqrtdiag_lhs=new DREAL[l->get_num_vectors()];
00067 ld_mean_lhs=new DREAL[l->get_num_vectors()];
00068
00069 for (i=0; i<l->get_num_vectors(); i++)
00070 sqrtdiag_lhs[i]=1;
00071
00072 if (l==r)
00073 {
00074 sqrtdiag_rhs=sqrtdiag_lhs;
00075 ld_mean_rhs=ld_mean_lhs;
00076 }
00077 else
00078 {
00079 sqrtdiag_rhs=new DREAL[r->get_num_vectors()];
00080 for (i=0; i<r->get_num_vectors(); i++)
00081 sqrtdiag_rhs[i]=1;
00082
00083 ld_mean_rhs=new DREAL[r->get_num_vectors()];
00084 }
00085
00086 DREAL* l_ld_mean_lhs=ld_mean_lhs;
00087 DREAL* l_ld_mean_rhs=ld_mean_rhs;
00088
00089
00090 if (!initialized)
00091 {
00092 INT num_vectors=l->get_num_vectors();
00093 num_symbols=(INT) l->get_num_symbols();
00094 INT llen=l->get_vector_length(0);
00095 INT rlen=r->get_vector_length(0);
00096 num_params=(INT) llen*l->get_num_symbols();
00097 INT num_params2=(INT) llen*l->get_num_symbols()+rlen*r->get_num_symbols();
00098 if ((!estimate) || (!estimate->check_models()))
00099 {
00100 SG_ERROR( "no estimate available\n");
00101 return false ;
00102 } ;
00103 if (num_params2!=estimate->get_num_params())
00104 {
00105 SG_ERROR( "number of parameters of estimate and feature representation do not match\n");
00106 return false ;
00107 } ;
00108
00109 delete[] variance;
00110 delete[] mean;
00111 mean=new DREAL[num_params];
00112 ASSERT(mean);
00113 variance=new DREAL[num_params];
00114 ASSERT(variance);
00115
00116 for (i=0; i<num_params; i++)
00117 {
00118 mean[i]=0;
00119 variance[i]=0;
00120 }
00121
00122
00123
00124 for (i=0; i<num_vectors; i++)
00125 {
00126 INT len;
00127 WORD* vec=l->get_feature_vector(i, len);
00128
00129 for (INT j=0; j<len; j++)
00130 {
00131 INT idx=compute_index(j, vec[j]);
00132 DREAL theta_p = 1/estimate->log_derivative_pos_obsolete(vec[j], j) ;
00133 DREAL theta_n = 1/estimate->log_derivative_neg_obsolete(vec[j], j) ;
00134 DREAL value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ;
00135
00136 mean[idx] += value/num_vectors ;
00137 }
00138 }
00139
00140
00141 for (i=0; i<num_vectors; i++)
00142 {
00143 INT len;
00144 WORD* vec=l->get_feature_vector(i, len);
00145
00146 for (INT j=0; j<len; j++)
00147 {
00148 for (INT k=0; k<4; k++)
00149 {
00150 INT idx=compute_index(j, k);
00151 if (k!=vec[j])
00152 variance[idx]+=mean[idx]*mean[idx]/num_vectors;
00153 else
00154 {
00155 DREAL theta_p = 1/estimate->log_derivative_pos_obsolete(vec[j], j) ;
00156 DREAL theta_n = 1/estimate->log_derivative_neg_obsolete(vec[j], j) ;
00157 DREAL value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ;
00158
00159 variance[idx] += CMath::sq(value-mean[idx])/num_vectors;
00160 }
00161 }
00162 }
00163 }
00164
00165
00166
00167 sum_m2_s2=0 ;
00168 for (i=0; i<num_params; i++)
00169 {
00170 if (variance[i]<1e-14)
00171 variance[i]=1 ;
00172
00173
00174 sum_m2_s2 += mean[i]*mean[i]/(variance[i]) ;
00175 } ;
00176 }
00177
00178
00179
00180
00181 for (i=0; i<l->get_num_vectors(); i++)
00182 {
00183 INT alen ;
00184 WORD* avec=l->get_feature_vector(i, alen);
00185 DREAL result=0 ;
00186 for (INT j=0; j<alen; j++)
00187 {
00188 INT a_idx = compute_index(j, avec[j]) ;
00189 DREAL theta_p = 1/estimate->log_derivative_pos_obsolete(avec[j], j) ;
00190 DREAL theta_n = 1/estimate->log_derivative_neg_obsolete(avec[j], j) ;
00191 DREAL value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ;
00192
00193 if (variance[a_idx]!=0)
00194 result-=value*mean[a_idx]/variance[a_idx];
00195 }
00196 ld_mean_lhs[i]=result ;
00197 }
00198
00199 if (ld_mean_lhs!=ld_mean_rhs)
00200 {
00201
00202
00203 for (i=0; i<r->get_num_vectors(); i++)
00204 {
00205 INT alen ;
00206 WORD* avec=r->get_feature_vector(i, alen);
00207 DREAL result=0 ;
00208 for (INT j=0; j<alen; j++)
00209 {
00210 INT a_idx = compute_index(j, avec[j]) ;
00211 DREAL theta_p = 1/estimate->log_derivative_pos_obsolete(avec[j], j) ;
00212 DREAL theta_n = 1/estimate->log_derivative_neg_obsolete(avec[j], j) ;
00213 DREAL value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ;
00214
00215 result -= value*mean[a_idx]/variance[a_idx] ;
00216 }
00217 ld_mean_rhs[i]=result ;
00218 } ;
00219 } ;
00220
00221
00222
00223 this->lhs=l;
00224 this->rhs=l;
00225 ld_mean_lhs = l_ld_mean_lhs ;
00226 ld_mean_rhs = l_ld_mean_lhs ;
00227
00228
00229 for (i=0; i<lhs->get_num_vectors(); i++)
00230 {
00231 sqrtdiag_lhs[i]=sqrt(compute(i,i));
00232
00233
00234 if (sqrtdiag_lhs[i]==0)
00235 sqrtdiag_lhs[i]=1e-16;
00236 }
00237
00238
00239
00240 if (sqrtdiag_lhs!=sqrtdiag_rhs)
00241 {
00242 this->lhs=r;
00243 this->rhs=r;
00244 ld_mean_lhs = l_ld_mean_rhs ;
00245 ld_mean_rhs = l_ld_mean_rhs ;
00246
00247
00248 for (i=0; i<rhs->get_num_vectors(); i++)
00249 {
00250 sqrtdiag_rhs[i]=sqrt(compute(i,i));
00251
00252
00253 if (sqrtdiag_rhs[i]==0)
00254 sqrtdiag_rhs[i]=1e-16;
00255 }
00256 }
00257
00258 this->lhs=l;
00259 this->rhs=r;
00260 ld_mean_lhs = l_ld_mean_lhs ;
00261 ld_mean_rhs = l_ld_mean_rhs ;
00262
00263 initialized = true ;
00264 return status;
00265 }
00266
00267 void CSalzbergWordKernel::cleanup()
00268 {
00269 delete[] variance;
00270 variance=NULL;
00271
00272 delete[] mean;
00273 mean=NULL;
00274
00275 if (sqrtdiag_lhs != sqrtdiag_rhs)
00276 delete[] sqrtdiag_rhs;
00277 sqrtdiag_rhs=NULL;
00278
00279 delete[] sqrtdiag_lhs;
00280 sqrtdiag_lhs=NULL;
00281
00282 if (ld_mean_lhs!=ld_mean_rhs)
00283 delete[] ld_mean_rhs ;
00284 ld_mean_rhs=NULL;
00285
00286 delete[] ld_mean_lhs ;
00287 ld_mean_lhs=NULL;
00288
00289 CKernel::cleanup();
00290 }
00291
00292 bool CSalzbergWordKernel::load_init(FILE* src)
00293 {
00294 return false;
00295 }
00296
00297 bool CSalzbergWordKernel::save_init(FILE* dest)
00298 {
00299 return false;
00300 }
00301
00302
00303
00304 DREAL CSalzbergWordKernel::compute(INT idx_a, INT idx_b)
00305 {
00306 INT alen, blen;
00307 WORD* avec=((CStringFeatures<WORD>*) lhs)->get_feature_vector(idx_a, alen);
00308 WORD* bvec=((CStringFeatures<WORD>*) rhs)->get_feature_vector(idx_b, blen);
00309
00310 ASSERT(alen==blen);
00311
00312 DREAL result = sum_m2_s2 ;
00313
00314 for (INT i=0; i<alen; i++)
00315 {
00316 if (avec[i]==bvec[i])
00317 {
00318 INT a_idx = compute_index(i, avec[i]) ;
00319
00320 DREAL theta_p = 1/estimate->log_derivative_pos_obsolete(avec[i], i) ;
00321 DREAL theta_n = 1/estimate->log_derivative_neg_obsolete(avec[i], i) ;
00322 DREAL value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ;
00323
00324 result += value*value/variance[a_idx] ;
00325 }
00326 }
00327 result += ld_mean_lhs[idx_a] + ld_mean_rhs[idx_b] ;
00328
00329
00330 if (initialized)
00331 result /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ;
00332
00333
00334 return result;
00335 }