00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/config.h"
00012
00013 #include <stdio.h>
00014 #include <string.h>
00015
00016 #include "lib/io.h"
00017
00018 #include "lib/matlab.h"
00019
00020 #include "structure/Plif.h"
00021
00022
00023
00024 CPlif::CPlif(int32_t l)
00025 : CPlifBase()
00026 {
00027 limits=NULL;
00028 penalties=NULL;
00029 cum_derivatives=NULL;
00030 id=-1;
00031 transform=T_LINEAR;
00032 name=NULL;
00033 max_value=0;
00034 min_value=0;
00035 cache=NULL;
00036 use_svm=0;
00037 use_cache=false;
00038 len=0;
00039 do_calc = true;
00040 if (l>0)
00041 set_plif_length(l);
00042 }
00043
00044 CPlif::~CPlif()
00045 {
00046 delete[] limits;
00047 delete[] penalties;
00048 delete[] name;
00049 delete[] cache;
00050 delete[] cum_derivatives;
00051 }
00052
00053 bool CPlif::set_transform_type(const char *type_str)
00054 {
00055 delete[] cache ;
00056 cache=NULL ;
00057
00058 if (strcmp(type_str, "linear")==0)
00059 transform = T_LINEAR ;
00060 else if (strcmp(type_str, "")==0)
00061 transform = T_LINEAR ;
00062 else if (strcmp(type_str, "log")==0)
00063 transform = T_LOG ;
00064 else if (strcmp(type_str, "log(+1)")==0)
00065 transform = T_LOG_PLUS1 ;
00066 else if (strcmp(type_str, "log(+3)")==0)
00067 transform = T_LOG_PLUS3 ;
00068 else if (strcmp(type_str, "(+3)")==0)
00069 transform = T_LINEAR_PLUS3 ;
00070 else
00071 {
00072 SG_ERROR( "unknown transform type (%s)\n", type_str) ;
00073 return false ;
00074 }
00075 return true ;
00076 }
00077
00078 void CPlif::init_penalty_struct_cache()
00079 {
00080 if (!use_cache)
00081 return ;
00082 if (cache || use_svm)
00083 return ;
00084 if (max_value<=0)
00085 return ;
00086
00087 float64_t* local_cache=new float64_t[ ((int32_t) max_value) + 2] ;
00088
00089 if (local_cache)
00090 {
00091 for (int32_t i=0; i<=max_value; i++)
00092 {
00093 if (i<min_value)
00094 local_cache[i] = -CMath::INFTY ;
00095 else
00096 local_cache[i] = lookup_penalty(i, NULL) ;
00097 }
00098 }
00099 this->cache=local_cache ;
00100 }
00101
00102 void CPlif::set_name(char *p_name)
00103 {
00104 delete[] name ;
00105 name=new char[strlen(p_name)+1] ;
00106 strcpy(name,p_name) ;
00107 }
00108
#ifdef HAVE_MATLAB
/** Delete all P CPlif objects in PEN and then the array itself.
 *
 * Used on every error path of read_penalty_struct_from_cell() so that a
 * failed read does not leak the individually allocated PLiFs.
 */
static void plif_release_array(CPlif** PEN, int32_t P)
{
	for (int32_t k=0; k<P; k++)
		delete PEN[k] ;
	delete[] PEN ;
}

/** @return true if field is non-NULL, numeric and exactly 1x1. */
static bool plif_is_numeric_scalar(const mxArray* field)
{
	return field!=NULL && mxIsNumeric(field) &&
		mxGetM(field)==1 && mxGetN(field)==1 ;
}

/** Build an array of P PLiFs from a MATLAB cell array of structs.
 *
 * Each of the first P cell elements must be a struct with the fields:
 *  - "id": numeric scalar, 1-based, unique, in 1..P (stored at PEN[id-1])
 *  - "limits", "penalties": 1xN double row vectors of equal length N
 *  - "transform", "name": char arrays
 *  - "max_value", "min_value": numeric scalars in [-100*2^20, 100*2^20]
 *  - "use_svm", "use_cache": numeric scalars
 *
 * On any error all PLiFs built so far are released before the error is
 * reported, and NULL is returned.
 *
 * @param mx_penalty_info cell array holding the PLiF descriptions
 * @param P number of PLiFs to read
 * @return newly allocated array of P CPlif pointers, or NULL on error
 */
CPlif** read_penalty_struct_from_cell(
	const mxArray * mx_penalty_info, int32_t P)
{
	CPlif** PEN = new CPlif*[P] ;
	for (int32_t i=0; i<P; i++)
		PEN[i]=new CPlif() ;

	for (int32_t i=0; i<P; i++)
	{
		const mxArray* mx_elem = mxGetCell(mx_penalty_info, i) ;
		if (mx_elem==NULL || !mxIsStruct(mx_elem))
		{
			plif_release_array(PEN, P) ;
			SG_SERROR("empty cell element\n") ;
			return NULL ;
		}

		const mxArray* mx_id_field = mxGetField(mx_elem, 0, "id") ;
		if (!plif_is_numeric_scalar(mx_id_field))
		{
			plif_release_array(PEN, P) ;
			SG_SERROR( "missing id field\n") ;
			return NULL ;
		}

		// limits/penalties are read through mxGetPr(), which is only valid
		// for double arrays, so require mxIsDouble (not just mxIsNumeric).
		const mxArray* mx_limits_field = mxGetField(mx_elem, 0, "limits") ;
		if (mx_limits_field==NULL || !mxIsDouble(mx_limits_field) ||
				mxGetM(mx_limits_field)!=1)
		{
			plif_release_array(PEN, P) ;
			SG_SERROR( "missing limits field\n") ;
			return NULL ;
		}
		int32_t len = (int32_t) mxGetN(mx_limits_field) ;

		const mxArray* mx_penalties_field = mxGetField(mx_elem, 0, "penalties") ;
		if (mx_penalties_field==NULL || !mxIsDouble(mx_penalties_field) ||
				mxGetM(mx_penalties_field)!=1 || ((int32_t) mxGetN(mx_penalties_field))!=len)
		{
			plif_release_array(PEN, P) ;
			SG_SERROR( "missing penalties field (%i)\n", i) ;
			return NULL ;
		}

		const mxArray* mx_transform_field = mxGetField(mx_elem, 0, "transform") ;
		if (mx_transform_field==NULL || !mxIsChar(mx_transform_field))
		{
			plif_release_array(PEN, P) ;
			SG_SERROR( "missing transform field\n") ;
			return NULL ;
		}

		const mxArray* mx_name_field = mxGetField(mx_elem, 0, "name") ;
		if (mx_name_field==NULL || !mxIsChar(mx_name_field))
		{
			plif_release_array(PEN, P) ;
			SG_SERROR( "missing name field\n") ;
			return NULL ;
		}

		const mxArray* mx_max_value_field = mxGetField(mx_elem, 0, "max_value") ;
		if (!plif_is_numeric_scalar(mx_max_value_field))
		{
			plif_release_array(PEN, P) ;
			SG_SERROR( "missing max_value field\n") ;
			return NULL ;
		}

		const mxArray* mx_min_value_field = mxGetField(mx_elem, 0, "min_value") ;
		if (!plif_is_numeric_scalar(mx_min_value_field))
		{
			plif_release_array(PEN, P) ;
			SG_SERROR( "missing min_value field\n") ;
			return NULL ;
		}

		const mxArray* mx_use_svm_field = mxGetField(mx_elem, 0, "use_svm") ;
		if (!plif_is_numeric_scalar(mx_use_svm_field))
		{
			plif_release_array(PEN, P) ;
			SG_SERROR( "missing use_svm field\n") ;
			return NULL ;
		}
		int32_t use_svm = (int32_t) mxGetScalar(mx_use_svm_field) ;

		const mxArray* mx_use_cache_field = mxGetField(mx_elem, 0, "use_cache") ;
		if (!plif_is_numeric_scalar(mx_use_cache_field))
		{
			plif_release_array(PEN, P) ;
			SG_SERROR( "missing use_cache field\n") ;
			return NULL ;
		}
		int32_t use_cache = (int32_t) mxGetScalar(mx_use_cache_field) ;

		// "id" is 1-based in MATLAB and indexes PEN directly, so it (not the
		// loop counter) must lie in 0..P-1 after conversion.
		int32_t id = (int32_t) mxGetScalar(mx_id_field)-1 ;
		if (id<0 || id>P-1)
		{
			plif_release_array(PEN, P) ;
			SG_SERROR( "id out of range\n") ;
			return NULL ;
		}

		int32_t max_value = (int32_t) mxGetScalar(mx_max_value_field) ;
		if (max_value<-1024*1024*100 || max_value>1024*1024*100)
		{
			plif_release_array(PEN, P) ;
			SG_SERROR( "max_value out of range\n") ;
			return NULL ;
		}
		PEN[id]->set_max_value(max_value) ;

		int32_t min_value = (int32_t) mxGetScalar(mx_min_value_field) ;
		if (min_value<-1024*1024*100 || min_value>1024*1024*100)
		{
			plif_release_array(PEN, P) ;
			SG_SERROR( "min_value out of range\n") ;
			return NULL ;
		}
		PEN[id]->set_min_value(min_value) ;

		// a freshly constructed CPlif has id -1; anything else means this
		// slot was already filled by an earlier cell element
		if (PEN[id]->get_id()!=-1)
		{
			plif_release_array(PEN, P) ;
			SG_SERROR( "penalty id already used\n") ;
			return NULL ;
		}
		PEN[id]->set_id(id) ;

		PEN[id]->set_use_svm(use_svm) ;
		PEN[id]->set_use_cache(use_cache) ;

		double * limits = mxGetPr(mx_limits_field) ;
		double * penalties = mxGetPr(mx_penalties_field) ;
		PEN[id]->set_plif(len, limits, penalties) ;

		char *transform_str = mxArrayToString(mx_transform_field) ;
		char *name_str = mxArrayToString(mx_name_field) ;

		// mxArrayToString can fail and return NULL; never forward NULL to
		// set_transform_type()/set_name().
		if (transform_str==NULL || name_str==NULL)
		{
			plif_release_array(PEN, P) ;
			if (transform_str!=NULL)
				mxFree(transform_str) ;
			if (name_str!=NULL)
				mxFree(name_str) ;
			SG_SERROR( "could not convert transform/name field to string\n") ;
			return NULL ;
		}

		if (!PEN[id]->set_transform_type(transform_str))
		{
			plif_release_array(PEN, P) ;
			SG_SERROR( "transform type not recognized ('%s')\n", transform_str) ;
			mxFree(transform_str) ;
			mxFree(name_str) ;
			return NULL ;
		}

		PEN[id]->set_name(name_str) ;
		PEN[id]->init_penalty_struct_cache() ;

		mxFree(transform_str) ;
		mxFree(name_str) ;
	}
	return PEN ;
}
#endif
00276
00277 void delete_penalty_struct(CPlif** PEN, int32_t P)
00278 {
00279 for (int32_t i=0; i<P; i++)
00280 delete PEN[i] ;
00281 delete[] PEN ;
00282 }
00283
00284 float64_t CPlif::lookup_penalty_svm(
00285 float64_t p_value, float64_t *d_values) const
00286 {
00287 ASSERT(use_svm>0);
00288 float64_t d_value=d_values[use_svm-1] ;
00289 #ifdef PLIF_DEBUG
00290 SG_PRINT("%s.lookup_penalty_svm(%f)\n", get_name(), d_value) ;
00291 #endif
00292
00293 if (!do_calc)
00294 return d_value;
00295 switch (transform)
00296 {
00297 case T_LINEAR:
00298 break ;
00299 case T_LOG:
00300 d_value = log(d_value) ;
00301 break ;
00302 case T_LOG_PLUS1:
00303 d_value = log(d_value+1) ;
00304 break ;
00305 case T_LOG_PLUS3:
00306 d_value = log(d_value+3) ;
00307 break ;
00308 case T_LINEAR_PLUS3:
00309 d_value = d_value+3 ;
00310 break ;
00311 default:
00312 SG_ERROR("unknown transform\n");
00313 break ;
00314 }
00315
00316 int32_t idx = 0 ;
00317 float64_t ret ;
00318 for (int32_t i=0; i<len; i++)
00319 if (limits[i]<=d_value)
00320 idx++ ;
00321 else
00322 break ;
00323
00324 #ifdef PLIF_DEBUG
00325 SG_PRINT(" -> idx = %i ", idx) ;
00326 #endif
00327
00328 if (idx==0)
00329 ret=penalties[0] ;
00330 else if (idx==len)
00331 ret=penalties[len-1] ;
00332 else
00333 {
00334 ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]*
00335 (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ;
00336 #ifdef PLIF_DEBUG
00337 SG_PRINT(" -> (%1.3f*%1.3f, %1.3f*%1.3f)", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) ;
00338 #endif
00339 }
00340 #ifdef PLIF_DEBUG
00341 SG_PRINT(" -> ret=%1.3f\n", ret) ;
00342 #endif
00343
00344 return ret ;
00345 }
00346
00347 float64_t CPlif::lookup_penalty(int32_t p_value, float64_t* svm_values) const
00348 {
00349 if (use_svm)
00350 return lookup_penalty_svm(p_value, svm_values) ;
00351
00352 if ((p_value<min_value) || (p_value>max_value))
00353 return -CMath::INFTY ;
00354 if (!do_calc)
00355 return p_value;
00356 if (cache!=NULL && (p_value>=0) && (p_value<=max_value))
00357 {
00358 float64_t ret=cache[p_value] ;
00359 return ret ;
00360 }
00361 return lookup_penalty((float64_t) p_value, svm_values) ;
00362 }
00363
00364 float64_t CPlif::lookup_penalty(float64_t p_value, float64_t* svm_values) const
00365 {
00366 if (use_svm)
00367 return lookup_penalty_svm(p_value, svm_values) ;
00368
00369 #ifdef PLIF_DEBUG
00370 SG_PRINT("%s.lookup_penalty(%f)\n", get_name(), p_value) ;
00371 #endif
00372
00373
00374 if ((p_value<min_value) || (p_value>max_value))
00375 return -CMath::INFTY ;
00376
00377 if (!do_calc)
00378 return p_value;
00379
00380 float64_t d_value = (float64_t) p_value ;
00381 switch (transform)
00382 {
00383 case T_LINEAR:
00384 break ;
00385 case T_LOG:
00386 d_value = log(d_value) ;
00387 break ;
00388 case T_LOG_PLUS1:
00389 d_value = log(d_value+1) ;
00390 break ;
00391 case T_LOG_PLUS3:
00392 d_value = log(d_value+3) ;
00393 break ;
00394 case T_LINEAR_PLUS3:
00395 d_value = d_value+3 ;
00396 break ;
00397 default:
00398 SG_ERROR( "unknown transform\n") ;
00399 break ;
00400 }
00401
00402 #ifdef PLIF_DEBUG
00403 SG_PRINT(" -> value = %1.4f ", d_value) ;
00404 #endif
00405
00406 int32_t idx = 0 ;
00407 float64_t ret ;
00408 for (int32_t i=0; i<len; i++)
00409 if (limits[i]<=d_value)
00410 idx++ ;
00411 else
00412 break ;
00413
00414 #ifdef PLIF_DEBUG
00415 SG_PRINT(" -> idx = %i ", idx) ;
00416 #endif
00417
00418 if (idx==0)
00419 ret=penalties[0] ;
00420 else if (idx==len)
00421 ret=penalties[len-1] ;
00422 else
00423 {
00424 ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]*
00425 (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ;
00426 #ifdef PLIF_DEBUG
00427 SG_PRINT(" -> (%1.3f*%1.3f, %1.3f*%1.3f) ", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) ;
00428 #endif
00429 }
00430
00431
00432 #ifdef PLIF_DEBUG
00433 SG_PRINT(" -> ret=%1.3f\n", ret) ;
00434 #endif
00435
00436 return ret ;
00437 }
00438
00439 void CPlif::penalty_clear_derivative()
00440 {
00441 for (int32_t i=0; i<len; i++)
00442 cum_derivatives[i]=0.0 ;
00443 }
00444
00445 void CPlif::penalty_add_derivative(float64_t p_value, float64_t* svm_values)
00446 {
00447 if (use_svm)
00448 {
00449 penalty_add_derivative_svm(p_value, svm_values) ;
00450 return ;
00451 }
00452
00453 if ((p_value<min_value) || (p_value>max_value))
00454 {
00455 return ;
00456 }
00457 float64_t d_value = (float64_t) p_value ;
00458 switch (transform)
00459 {
00460 case T_LINEAR:
00461 break ;
00462 case T_LOG:
00463 d_value = log(d_value) ;
00464 break ;
00465 case T_LOG_PLUS1:
00466 d_value = log(d_value+1) ;
00467 break ;
00468 case T_LOG_PLUS3:
00469 d_value = log(d_value+3) ;
00470 break ;
00471 case T_LINEAR_PLUS3:
00472 d_value = d_value+3 ;
00473 break ;
00474 default:
00475 SG_ERROR( "unknown transform\n") ;
00476 break ;
00477 }
00478
00479 int32_t idx = 0 ;
00480 for (int32_t i=0; i<len; i++)
00481 if (limits[i]<=d_value)
00482 idx++ ;
00483 else
00484 break ;
00485
00486 if (idx==0)
00487 cum_derivatives[0]+=1 ;
00488 else if (idx==len)
00489 cum_derivatives[len-1]+=1 ;
00490 else
00491 {
00492 cum_derivatives[idx]+=(d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ;
00493 cum_derivatives[idx-1]+=(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ;
00494 }
00495 }
00496
00497 void CPlif::penalty_add_derivative_svm(float64_t p_value, float64_t *d_values)
00498 {
00499 ASSERT(use_svm>0);
00500 float64_t d_value=d_values[use_svm-1] ;
00501
00502 if (d_value<-1e+20)
00503 return;
00504
00505 switch (transform)
00506 {
00507 case T_LINEAR:
00508 break ;
00509 case T_LOG:
00510 d_value = log(d_value) ;
00511 break ;
00512 case T_LOG_PLUS1:
00513 d_value = log(d_value+1) ;
00514 break ;
00515 case T_LOG_PLUS3:
00516 d_value = log(d_value+3) ;
00517 break ;
00518 case T_LINEAR_PLUS3:
00519 d_value = d_value+3 ;
00520 break ;
00521 default:
00522 SG_ERROR( "unknown transform\n") ;
00523 break ;
00524 }
00525
00526 int32_t idx = 0 ;
00527 for (int32_t i=0; i<len; i++)
00528 if (limits[i]<=d_value)
00529 idx++ ;
00530 else
00531 break ;
00532
00533 if (idx==0)
00534 cum_derivatives[0]+=1 ;
00535 else if (idx==len)
00536 cum_derivatives[len-1]+=1 ;
00537 else
00538 {
00539 cum_derivatives[idx]+=(d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ;
00540 cum_derivatives[idx-1]+=(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ;
00541 }
00542 }
00543
00544 void CPlif::get_used_svms(int32_t* num_svms, int32_t* svm_ids)
00545 {
00546 if (use_svm)
00547 {
00548 svm_ids[(*num_svms)] = use_svm;
00549 (*num_svms)++;
00550 }
00551 SG_PRINT("->use_svm:%i plif_id:%i name:%s trans_type:%s ",use_svm, get_id(), get_name(), get_transform_type());
00552 }
00553
00554 bool CPlif::get_do_calc()
00555 {
00556 return do_calc;
00557 }
00558
00559 void CPlif::set_do_calc(bool b)
00560 {
00561 do_calc = b;;
00562 }