00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "kernel/SalzbergWordStringKernel.h"
00014 #include "features/Features.h"
00015 #include "features/StringFeatures.h"
00016 #include "features/Labels.h"
00017 #include "classifier/PluginEstimate.h"
00018
00019 CSalzbergWordStringKernel::CSalzbergWordStringKernel(int32_t size, CPluginEstimate* pie, CLabels* labels)
00020 : CStringKernel<uint16_t>(size), estimate(pie), mean(NULL), variance(NULL),
00021 sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL),
00022 ld_mean_lhs(NULL), ld_mean_rhs(NULL),
00023 num_params(0), num_symbols(0), sum_m2_s2(0), pos_prior(0.5),
00024 neg_prior(0.5), initialized(false)
00025 {
00026 if (labels)
00027 set_prior_probs_from_labels(labels);
00028 }
00029
00030 CSalzbergWordStringKernel::CSalzbergWordStringKernel(
00031 CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r,
00032 CPluginEstimate* pie, CLabels* labels)
00033 : CStringKernel<uint16_t>(10),estimate(pie), mean(NULL), variance(NULL),
00034 sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL),
00035 ld_mean_lhs(NULL), ld_mean_rhs(NULL),
00036 num_params(0), num_symbols(0), sum_m2_s2(0), pos_prior(0.5),
00037 neg_prior(0.5), initialized(false)
00038 {
00039 if (labels)
00040 set_prior_probs_from_labels(labels);
00041
00042 init(l, r);
00043 }
00044
00045 CSalzbergWordStringKernel::~CSalzbergWordStringKernel()
00046 {
00047 cleanup();
00048 }
00049
00050 bool CSalzbergWordStringKernel::init(CFeatures* p_l, CFeatures* p_r)
00051 {
00052 CStringKernel<uint16_t>::init(p_l,p_r);
00053 CStringFeatures<uint16_t>* l=(CStringFeatures<uint16_t>*) p_l;
00054 ASSERT(l);
00055 CStringFeatures<uint16_t>* r=(CStringFeatures<uint16_t>*) p_r;
00056 ASSERT(r);
00057
00058 int32_t i;
00059 initialized=false;
00060
00061 if (sqrtdiag_lhs!=sqrtdiag_rhs)
00062 delete[] sqrtdiag_rhs;
00063 sqrtdiag_rhs=NULL;
00064 delete[] sqrtdiag_lhs;
00065 sqrtdiag_lhs=NULL;
00066 if (ld_mean_lhs!=ld_mean_rhs)
00067 delete[] ld_mean_rhs;
00068 ld_mean_rhs=NULL;
00069 delete[] ld_mean_lhs;
00070 ld_mean_lhs=NULL;
00071
00072 sqrtdiag_lhs=new float64_t[l->get_num_vectors()];
00073 ld_mean_lhs=new float64_t[l->get_num_vectors()];
00074
00075 for (i=0; i<l->get_num_vectors(); i++)
00076 sqrtdiag_lhs[i]=1;
00077
00078 if (l==r)
00079 {
00080 sqrtdiag_rhs=sqrtdiag_lhs;
00081 ld_mean_rhs=ld_mean_lhs;
00082 }
00083 else
00084 {
00085 sqrtdiag_rhs=new float64_t[r->get_num_vectors()];
00086 for (i=0; i<r->get_num_vectors(); i++)
00087 sqrtdiag_rhs[i]=1;
00088
00089 ld_mean_rhs=new float64_t[r->get_num_vectors()];
00090 }
00091
00092 float64_t* l_ld_mean_lhs=ld_mean_lhs;
00093 float64_t* l_ld_mean_rhs=ld_mean_rhs;
00094
00095
00096 if (!initialized)
00097 {
00098 int32_t num_vectors=l->get_num_vectors();
00099 num_symbols=(int32_t) l->get_num_symbols();
00100 int32_t llen=l->get_vector_length(0);
00101 int32_t rlen=r->get_vector_length(0);
00102 num_params=(int32_t) llen*l->get_num_symbols();
00103 int32_t num_params2=(int32_t) llen*l->get_num_symbols()+rlen*r->get_num_symbols();
00104 if ((!estimate) || (!estimate->check_models()))
00105 {
00106 SG_ERROR( "no estimate available\n");
00107 return false ;
00108 } ;
00109 if (num_params2!=estimate->get_num_params())
00110 {
00111 SG_ERROR( "number of parameters of estimate and feature representation do not match\n");
00112 return false ;
00113 } ;
00114
00115 delete[] variance;
00116 delete[] mean;
00117 mean=new float64_t[num_params];
00118 ASSERT(mean);
00119 variance=new float64_t[num_params];
00120 ASSERT(variance);
00121
00122 for (i=0; i<num_params; i++)
00123 {
00124 mean[i]=0;
00125 variance[i]=0;
00126 }
00127
00128
00129
00130 for (i=0; i<num_vectors; i++)
00131 {
00132 int32_t len;
00133 uint16_t* vec=l->get_feature_vector(i, len);
00134
00135 for (int32_t j=0; j<len; j++)
00136 {
00137 int32_t idx=compute_index(j, vec[j]);
00138 float64_t theta_p = 1/estimate->log_derivative_pos_obsolete(vec[j], j) ;
00139 float64_t theta_n = 1/estimate->log_derivative_neg_obsolete(vec[j], j) ;
00140 float64_t value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ;
00141
00142 mean[idx] += value/num_vectors ;
00143 }
00144 }
00145
00146
00147 for (i=0; i<num_vectors; i++)
00148 {
00149 int32_t len;
00150 uint16_t* vec=l->get_feature_vector(i, len);
00151
00152 for (int32_t j=0; j<len; j++)
00153 {
00154 for (int32_t k=0; k<4; k++)
00155 {
00156 int32_t idx=compute_index(j, k);
00157 if (k!=vec[j])
00158 variance[idx]+=mean[idx]*mean[idx]/num_vectors;
00159 else
00160 {
00161 float64_t theta_p = 1/estimate->log_derivative_pos_obsolete(vec[j], j) ;
00162 float64_t theta_n = 1/estimate->log_derivative_neg_obsolete(vec[j], j) ;
00163 float64_t value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ;
00164
00165 variance[idx] += CMath::sq(value-mean[idx])/num_vectors;
00166 }
00167 }
00168 }
00169 }
00170
00171
00172
00173 sum_m2_s2=0 ;
00174 for (i=0; i<num_params; i++)
00175 {
00176 if (variance[i]<1e-14)
00177 variance[i]=1 ;
00178
00179
00180 sum_m2_s2 += mean[i]*mean[i]/(variance[i]) ;
00181 } ;
00182 }
00183
00184
00185
00186
00187 for (i=0; i<l->get_num_vectors(); i++)
00188 {
00189 int32_t alen ;
00190 uint16_t* avec=l->get_feature_vector(i, alen);
00191 float64_t result=0 ;
00192 for (int32_t j=0; j<alen; j++)
00193 {
00194 int32_t a_idx = compute_index(j, avec[j]) ;
00195 float64_t theta_p = 1/estimate->log_derivative_pos_obsolete(avec[j], j) ;
00196 float64_t theta_n = 1/estimate->log_derivative_neg_obsolete(avec[j], j) ;
00197 float64_t value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ;
00198
00199 if (variance[a_idx]!=0)
00200 result-=value*mean[a_idx]/variance[a_idx];
00201 }
00202 ld_mean_lhs[i]=result ;
00203 }
00204
00205 if (ld_mean_lhs!=ld_mean_rhs)
00206 {
00207
00208
00209 for (i=0; i<r->get_num_vectors(); i++)
00210 {
00211 int32_t alen ;
00212 uint16_t* avec=r->get_feature_vector(i, alen);
00213 float64_t result=0 ;
00214 for (int32_t j=0; j<alen; j++)
00215 {
00216 int32_t a_idx = compute_index(j, avec[j]) ;
00217 float64_t theta_p=1/estimate->log_derivative_pos_obsolete(
00218 avec[j], j) ;
00219 float64_t theta_n=1/estimate->log_derivative_neg_obsolete(
00220 avec[j], j) ;
00221 float64_t value=(theta_p/(pos_prior*theta_p+neg_prior*theta_n));
00222
00223 result -= value*mean[a_idx]/variance[a_idx] ;
00224 }
00225 ld_mean_rhs[i]=result ;
00226 } ;
00227 } ;
00228
00229
00230
00231 this->lhs=l;
00232 this->rhs=l;
00233 ld_mean_lhs = l_ld_mean_lhs ;
00234 ld_mean_rhs = l_ld_mean_lhs ;
00235
00236
00237 for (i=0; i<lhs->get_num_vectors(); i++)
00238 {
00239 sqrtdiag_lhs[i]=sqrt(compute(i,i));
00240
00241
00242 if (sqrtdiag_lhs[i]==0)
00243 sqrtdiag_lhs[i]=1e-16;
00244 }
00245
00246
00247
00248 if (sqrtdiag_lhs!=sqrtdiag_rhs)
00249 {
00250 this->lhs=r;
00251 this->rhs=r;
00252 ld_mean_lhs = l_ld_mean_rhs ;
00253 ld_mean_rhs = l_ld_mean_rhs ;
00254
00255
00256 for (i=0; i<rhs->get_num_vectors(); i++)
00257 {
00258 sqrtdiag_rhs[i]=sqrt(compute(i,i));
00259
00260
00261 if (sqrtdiag_rhs[i]==0)
00262 sqrtdiag_rhs[i]=1e-16;
00263 }
00264 }
00265
00266 this->lhs=l;
00267 this->rhs=r;
00268 ld_mean_lhs = l_ld_mean_lhs ;
00269 ld_mean_rhs = l_ld_mean_rhs ;
00270
00271 initialized = true ;
00272 return init_normalizer();
00273 }
00274
00275 void CSalzbergWordStringKernel::cleanup()
00276 {
00277 delete[] variance;
00278 variance=NULL;
00279
00280 delete[] mean;
00281 mean=NULL;
00282
00283 if (sqrtdiag_lhs != sqrtdiag_rhs)
00284 delete[] sqrtdiag_rhs;
00285 sqrtdiag_rhs=NULL;
00286
00287 delete[] sqrtdiag_lhs;
00288 sqrtdiag_lhs=NULL;
00289
00290 if (ld_mean_lhs!=ld_mean_rhs)
00291 delete[] ld_mean_rhs ;
00292 ld_mean_rhs=NULL;
00293
00294 delete[] ld_mean_lhs ;
00295 ld_mean_lhs=NULL;
00296
00297 CKernel::cleanup();
00298 }
00299
00300 bool CSalzbergWordStringKernel::load_init(FILE* src)
00301 {
00302 return false;
00303 }
00304
00305 bool CSalzbergWordStringKernel::save_init(FILE* dest)
00306 {
00307 return false;
00308 }
00309
00310
00311
00312 float64_t CSalzbergWordStringKernel::compute(int32_t idx_a, int32_t idx_b)
00313 {
00314 int32_t alen, blen;
00315 uint16_t* avec=((CStringFeatures<uint16_t>*) lhs)->get_feature_vector(idx_a, alen);
00316 uint16_t* bvec=((CStringFeatures<uint16_t>*) rhs)->get_feature_vector(idx_b, blen);
00317
00318 ASSERT(alen==blen);
00319
00320 float64_t result = sum_m2_s2 ;
00321
00322 for (int32_t i=0; i<alen; i++)
00323 {
00324 if (avec[i]==bvec[i])
00325 {
00326 int32_t a_idx = compute_index(i, avec[i]) ;
00327
00328 float64_t theta_p = 1/estimate->log_derivative_pos_obsolete(avec[i], i) ;
00329 float64_t theta_n = 1/estimate->log_derivative_neg_obsolete(avec[i], i) ;
00330 float64_t value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ;
00331
00332 result += value*value/variance[a_idx] ;
00333 }
00334 }
00335 result += ld_mean_lhs[idx_a] + ld_mean_rhs[idx_b] ;
00336
00337
00338 if (initialized)
00339 result /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ;
00340
00341 return result;
00342 }
00343
00344 void CSalzbergWordStringKernel::set_prior_probs_from_labels(CLabels* labels)
00345 {
00346 ASSERT(labels);
00347
00348 int32_t num_pos=0, num_neg=0;
00349 for (int32_t i=0; i<labels->get_num_labels(); i++)
00350 {
00351 if (labels->get_int_label(i)==1)
00352 num_pos++;
00353 if (labels->get_int_label(i)==-1)
00354 num_neg++;
00355 }
00356
00357 SG_INFO("priors: pos=%1.3f (%i) neg=%1.3f (%i)\n",
00358 (float64_t) num_pos/(num_pos+num_neg), num_pos,
00359 (float64_t) num_neg/(num_pos+num_neg), num_neg);
00360
00361 set_prior_probs(
00362 (float64_t)num_pos/(num_pos+num_neg),
00363 (float64_t)num_neg/(num_pos+num_neg));
00364 }