00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/config.h"
00012
00013 #include <stdio.h>
00014 #include <string.h>
00015
00016 #include "lib/io.h"
00017
00018 #include "lib/matlab.h"
00019
00020 #include "structure/Plif.h"
00021
00022
00023
00024 CPlif::CPlif(INT l)
00025 : CPlifBase()
00026 {
00027 limits=NULL;
00028 penalties=NULL;
00029 cum_derivatives=NULL;
00030 id=-1;
00031 transform=T_LINEAR;
00032 name=NULL;
00033 max_value=0;
00034 min_value=0;
00035 cache=NULL;
00036 use_svm=0;
00037 use_cache=false;
00038 len=0;
00039 do_calc = true;
00040 if (l>0)
00041 set_plif_length(l);
00042 }
00043
00044 CPlif::~CPlif()
00045 {
00046 delete[] limits;
00047 delete[] penalties;
00048 delete[] name;
00049 delete[] cache;
00050 delete[] cum_derivatives;
00051 }
00052
00053 bool CPlif::set_transform_type(const char *type_str)
00054 {
00055 delete[] cache ;
00056 cache=NULL ;
00057
00058 if (strcmp(type_str, "linear")==0)
00059 transform = T_LINEAR ;
00060 else if (strcmp(type_str, "")==0)
00061 transform = T_LINEAR ;
00062 else if (strcmp(type_str, "log")==0)
00063 transform = T_LOG ;
00064 else if (strcmp(type_str, "log(+1)")==0)
00065 transform = T_LOG_PLUS1 ;
00066 else if (strcmp(type_str, "log(+3)")==0)
00067 transform = T_LOG_PLUS3 ;
00068 else if (strcmp(type_str, "(+3)")==0)
00069 transform = T_LINEAR_PLUS3 ;
00070 else
00071 {
00072 SG_ERROR( "unknown transform type (%s)\n", type_str) ;
00073 return false ;
00074 }
00075 return true ;
00076 }
00077 void CPlif::init_penalty_struct_cache()
00078 {
00079 if (!use_cache)
00080 return ;
00081 if (cache || use_svm)
00082 return ;
00083 if (max_value<=0)
00084 return ;
00085
00086 DREAL* local_cache=new DREAL[ ((INT) max_value) + 2] ;
00087
00088 if (local_cache)
00089 {
00090 for (INT i=0; i<=max_value; i++)
00091 {
00092 if (i<min_value)
00093 local_cache[i] = -CMath::INFTY ;
00094 else
00095 local_cache[i] = lookup_penalty(i, NULL) ;
00096 }
00097 }
00098 this->cache=local_cache ;
00099 }
00100
00101
00102 void CPlif::set_name(char *p_name)
00103 {
00104 delete[] name ;
00105 name=new char[strlen(p_name)+1] ;
00106 strcpy(name,p_name) ;
00107 }
00108
#ifdef HAVE_MATLAB
/** Delete every CPlif in PEN[0..P-1] and then the array itself.
 *  Used on the error paths below so a rejected input does not leak the
 *  P objects allocated up front. */
static void free_penalty_array(CPlif** PEN, INT P)
{
	for (INT i=0; i<P; i++)
		delete PEN[i] ;
	delete[] PEN ;
}

/** Return field `field` of struct element 0 if it is a numeric 1x1 array,
 *  otherwise NULL (missing field, non-numeric, or wrong size). */
static const mxArray* get_numeric_scalar_field(const mxArray* mx_elem, const char* field)
{
	const mxArray* mx_field = mxGetField(mx_elem, 0, field) ;
	if (mx_field==NULL || !mxIsNumeric(mx_field) ||
			mxGetM(mx_field)!=1 || mxGetN(mx_field)!=1)
		return NULL ;
	return mx_field ;
}

/** Truncate `value` toward zero into *out when the result lies in [lo, hi].
 *  The double is range-checked before the cast because converting NaN or an
 *  out-of-range double to an integer is undefined behaviour. Assumes INT
 *  has at least 32 bits. Returns false, leaving *out untouched, on rejection. */
static bool scalar_to_int_in_range(double value, INT lo, INT hi, INT* out)
{
	if (!(value > -2147483648.0 && value < 2147483648.0))  // also rejects NaN
		return false ;
	const INT truncated = (INT) value ;
	if (truncated<lo || truncated>hi)
		return false ;
	*out = truncated ;
	return true ;
}

/**
 * Build an array of P PLIFs from a MATLAB cell array of structs.
 *
 * Each cell element must be a struct with these fields:
 * - id: numeric scalar, 1-based, in 1..P;
 * - limits, penalties: 1xN numeric rows of equal length;
 * - transform, name: char arrays;
 * - max_value, min_value, use_svm, use_cache: numeric scalars.
 *
 * Element i is stored at PEN[id-1], and each id may appear only once.
 *
 * @param mx_penalty_info cell array holding (at least) P struct elements
 * @param P number of PLIFs to create
 * @return newly allocated array of P CPlif objects, or NULL if any element
 *         is invalid. On failure everything allocated here has already been
 *         released.
 */
CPlif** read_penalty_struct_from_cell(const mxArray * mx_penalty_info, INT P)
{
	// bound on |max_value| and |min_value| accepted from MATLAB
	const INT value_bound = 1024*1024*100 ;

	CPlif** PEN = new CPlif*[P] ;
	for (INT i=0; i<P; i++)
		PEN[i]=new CPlif() ;

	// Every error path frees PEN *before* SG_SERROR, so nothing leaks even
	// if the error reporter does not return.
	for (INT i=0; i<P; i++)
	{
		const mxArray* mx_elem = mxGetCell(mx_penalty_info, i) ;
		if (mx_elem==NULL || !mxIsStruct(mx_elem))
		{
			free_penalty_array(PEN, P) ;
			SG_SERROR("empty cell element\n") ;
			return NULL ;
		}

		const mxArray* mx_id_field = get_numeric_scalar_field(mx_elem, "id") ;
		if (mx_id_field==NULL)
		{
			free_penalty_array(PEN, P) ;
			SG_SERROR( "missing id field\n") ;
			return NULL ;
		}

		const mxArray* mx_limits_field = mxGetField(mx_elem, 0, "limits") ;
		if (mx_limits_field==NULL || !mxIsNumeric(mx_limits_field) ||
				mxGetM(mx_limits_field)!=1)
		{
			free_penalty_array(PEN, P) ;
			SG_SERROR( "missing limits field\n") ;
			return NULL ;
		}
		INT len = mxGetN(mx_limits_field) ;

		const mxArray* mx_penalties_field = mxGetField(mx_elem, 0, "penalties") ;
		if (mx_penalties_field==NULL || !mxIsNumeric(mx_penalties_field) ||
				mxGetM(mx_penalties_field)!=1 || ((INT) mxGetN(mx_penalties_field))!=len)
		{
			free_penalty_array(PEN, P) ;
			SG_SERROR( "missing penalties field (%i)\n", i) ;
			return NULL ;
		}

		const mxArray* mx_transform_field = mxGetField(mx_elem, 0, "transform") ;
		if (mx_transform_field==NULL || !mxIsChar(mx_transform_field))
		{
			free_penalty_array(PEN, P) ;
			SG_SERROR( "missing transform field\n") ;
			return NULL ;
		}

		const mxArray* mx_name_field = mxGetField(mx_elem, 0, "name") ;
		if (mx_name_field==NULL || !mxIsChar(mx_name_field))
		{
			free_penalty_array(PEN, P) ;
			SG_SERROR( "missing name field\n") ;
			return NULL ;
		}

		const mxArray* mx_max_value_field = get_numeric_scalar_field(mx_elem, "max_value") ;
		if (mx_max_value_field==NULL)
		{
			free_penalty_array(PEN, P) ;
			SG_SERROR( "missing max_value field\n") ;
			return NULL ;
		}

		const mxArray* mx_min_value_field = get_numeric_scalar_field(mx_elem, "min_value") ;
		if (mx_min_value_field==NULL)
		{
			free_penalty_array(PEN, P) ;
			SG_SERROR( "missing min_value field\n") ;
			return NULL ;
		}

		const mxArray* mx_use_svm_field = get_numeric_scalar_field(mx_elem, "use_svm") ;
		if (mx_use_svm_field==NULL)
		{
			free_penalty_array(PEN, P) ;
			SG_SERROR( "missing use_svm field\n") ;
			return NULL ;
		}
		INT use_svm = (INT) mxGetScalar(mx_use_svm_field) ;

		const mxArray* mx_use_cache_field = get_numeric_scalar_field(mx_elem, "use_cache") ;
		if (mx_use_cache_field==NULL)
		{
			free_penalty_array(PEN, P) ;
			SG_SERROR( "missing use_cache field\n") ;
			return NULL ;
		}
		INT use_cache = (INT) mxGetScalar(mx_use_cache_field) ;

		// MATLAB ids are 1-based; validate the id itself (previously the
		// loop index was checked, allowing out-of-bounds PEN[id] writes).
		INT one_based_id = 0 ;
		if (!scalar_to_int_in_range(mxGetScalar(mx_id_field), 1, P, &one_based_id))
		{
			free_penalty_array(PEN, P) ;
			SG_SERROR( "id out of range\n") ;
			return NULL ;
		}
		const INT id = one_based_id-1 ;

		INT max_value = 0 ;
		if (!scalar_to_int_in_range(mxGetScalar(mx_max_value_field),
					-value_bound, value_bound, &max_value))
		{
			free_penalty_array(PEN, P) ;
			SG_SERROR( "max_value out of range\n") ;
			return NULL ;
		}

		INT min_value = 0 ;
		if (!scalar_to_int_in_range(mxGetScalar(mx_min_value_field),
					-value_bound, value_bound, &min_value))
		{
			free_penalty_array(PEN, P) ;
			SG_SERROR( "min_value out of range\n") ;
			return NULL ;
		}

		CPlif* plif = PEN[id] ;
		if (plif->get_id()!=-1)
		{
			free_penalty_array(PEN, P) ;
			SG_SERROR( "penalty id already used\n") ;
			return NULL ;
		}

		plif->set_max_value(max_value) ;
		plif->set_min_value(min_value) ;
		plif->set_id(id) ;
		plif->set_use_svm(use_svm) ;
		plif->set_use_cache(use_cache) ;
		plif->set_plif(len, mxGetPr(mx_limits_field), mxGetPr(mx_penalties_field)) ;

		char *transform_str = mxArrayToString(mx_transform_field) ;
		char *name_str = mxArrayToString(mx_name_field) ;
		if (transform_str==NULL || name_str==NULL)
		{
			free_penalty_array(PEN, P) ;
			if (transform_str!=NULL)
				mxFree(transform_str) ;
			if (name_str!=NULL)
				mxFree(name_str) ;
			SG_SERROR( "could not read transform or name string\n") ;
			return NULL ;
		}

		if (!plif->set_transform_type(transform_str))
		{
			free_penalty_array(PEN, P) ;
			mxFree(name_str) ;
			SG_SERROR( "transform type not recognized ('%s')\n", transform_str) ;
			mxFree(transform_str) ;
			return NULL ;
		}

		plif->set_name(name_str) ;
		plif->init_penalty_struct_cache() ;

		mxFree(transform_str) ;
		mxFree(name_str) ;
	}
	return PEN ;
}
#endif
00275
00276 void delete_penalty_struct(CPlif** PEN, INT P)
00277 {
00278 for (INT i=0; i<P; i++)
00279 delete PEN[i] ;
00280 delete[] PEN ;
00281 }
00282
00283 DREAL CPlif::lookup_penalty_svm(DREAL p_value, DREAL *d_values) const
00284 {
00285 ASSERT(use_svm>0);
00286 DREAL d_value=d_values[use_svm-1] ;
00287 #ifdef PLIF_DEBUG
00288 SG_PRINT("%s.lookup_penalty_svm(%f)\n", get_name(), d_value) ;
00289 #endif
00290
00291 if (!do_calc)
00292 return d_value;
00293 switch (transform)
00294 {
00295 case T_LINEAR:
00296 break ;
00297 case T_LOG:
00298 d_value = log(d_value) ;
00299 break ;
00300 case T_LOG_PLUS1:
00301 d_value = log(d_value+1) ;
00302 break ;
00303 case T_LOG_PLUS3:
00304 d_value = log(d_value+3) ;
00305 break ;
00306 case T_LINEAR_PLUS3:
00307 d_value = d_value+3 ;
00308 break ;
00309 default:
00310 SG_ERROR("unknown transform\n");
00311 break ;
00312 }
00313
00314 INT idx = 0 ;
00315 DREAL ret ;
00316 for (INT i=0; i<len; i++)
00317 if (limits[i]<=d_value)
00318 idx++ ;
00319 else
00320 break ;
00321
00322 #ifdef PLIF_DEBUG
00323 SG_PRINT(" -> idx = %i ", idx) ;
00324 #endif
00325
00326 if (idx==0)
00327 ret=penalties[0] ;
00328 else if (idx==len)
00329 ret=penalties[len-1] ;
00330 else
00331 {
00332 ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]*
00333 (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ;
00334 #ifdef PLIF_DEBUG
00335 SG_PRINT(" -> (%1.3f*%1.3f, %1.3f*%1.3f)", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) ;
00336 #endif
00337 }
00338 #ifdef PLIF_DEBUG
00339 SG_PRINT(" -> ret=%1.3f\n", ret) ;
00340 #endif
00341
00342 return ret ;
00343 }
00344
00345 DREAL CPlif::lookup_penalty(INT p_value, DREAL* svm_values) const
00346 {
00347 if (use_svm)
00348 return lookup_penalty_svm(p_value, svm_values) ;
00349
00350 if ((p_value<min_value) || (p_value>max_value))
00351 return -CMath::INFTY ;
00352 if (!do_calc)
00353 return p_value;
00354 if (cache!=NULL && (p_value>=0) && (p_value<=max_value))
00355 {
00356 DREAL ret=cache[p_value] ;
00357 return ret ;
00358 }
00359 return lookup_penalty((DREAL) p_value, svm_values) ;
00360 }
00361
00362 DREAL CPlif::lookup_penalty(DREAL p_value, DREAL* svm_values) const
00363 {
00364 if (use_svm)
00365 return lookup_penalty_svm(p_value, svm_values) ;
00366
00367 #ifdef PLIF_DEBUG
00368 SG_PRINT("%s.lookup_penalty(%f)\n", get_name(), p_value) ;
00369 #endif
00370
00371
00372 if ((p_value<min_value) || (p_value>max_value))
00373 return -CMath::INFTY ;
00374
00375 if (!do_calc)
00376 return p_value;
00377
00378 DREAL d_value = (DREAL) p_value ;
00379 switch (transform)
00380 {
00381 case T_LINEAR:
00382 break ;
00383 case T_LOG:
00384 d_value = log(d_value) ;
00385 break ;
00386 case T_LOG_PLUS1:
00387 d_value = log(d_value+1) ;
00388 break ;
00389 case T_LOG_PLUS3:
00390 d_value = log(d_value+3) ;
00391 break ;
00392 case T_LINEAR_PLUS3:
00393 d_value = d_value+3 ;
00394 break ;
00395 default:
00396 SG_ERROR( "unknown transform\n") ;
00397 break ;
00398 }
00399
00400 #ifdef PLIF_DEBUG
00401 SG_PRINT(" -> value = %1.4f ", d_value) ;
00402 #endif
00403
00404 INT idx = 0 ;
00405 DREAL ret ;
00406 for (INT i=0; i<len; i++)
00407 if (limits[i]<=d_value)
00408 idx++ ;
00409 else
00410 break ;
00411
00412 #ifdef PLIF_DEBUG
00413 SG_PRINT(" -> idx = %i ", idx) ;
00414 #endif
00415
00416 if (idx==0)
00417 ret=penalties[0] ;
00418 else if (idx==len)
00419 ret=penalties[len-1] ;
00420 else
00421 {
00422 ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]*
00423 (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ;
00424 #ifdef PLIF_DEBUG
00425 SG_PRINT(" -> (%1.3f*%1.3f, %1.3f*%1.3f) ", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) ;
00426 #endif
00427 }
00428
00429
00430 #ifdef PLIF_DEBUG
00431 SG_PRINT(" -> ret=%1.3f\n", ret) ;
00432 #endif
00433
00434 return ret ;
00435 }
00436
00437 void CPlif::penalty_clear_derivative()
00438 {
00439 for (INT i=0; i<len; i++)
00440 cum_derivatives[i]=0.0 ;
00441 }
00442
00443 void CPlif::penalty_add_derivative(DREAL p_value, DREAL* svm_values)
00444 {
00445 if (use_svm)
00446 {
00447 penalty_add_derivative_svm(p_value, svm_values) ;
00448 return ;
00449 }
00450
00451 if ((p_value<min_value) || (p_value>max_value))
00452 {
00453 return ;
00454 }
00455 DREAL d_value = (DREAL) p_value ;
00456 switch (transform)
00457 {
00458 case T_LINEAR:
00459 break ;
00460 case T_LOG:
00461 d_value = log(d_value) ;
00462 break ;
00463 case T_LOG_PLUS1:
00464 d_value = log(d_value+1) ;
00465 break ;
00466 case T_LOG_PLUS3:
00467 d_value = log(d_value+3) ;
00468 break ;
00469 case T_LINEAR_PLUS3:
00470 d_value = d_value+3 ;
00471 break ;
00472 default:
00473 SG_ERROR( "unknown transform\n") ;
00474 break ;
00475 }
00476
00477 INT idx = 0 ;
00478 for (INT i=0; i<len; i++)
00479 if (limits[i]<=d_value)
00480 idx++ ;
00481 else
00482 break ;
00483
00484 if (idx==0)
00485 cum_derivatives[0]+=1 ;
00486 else if (idx==len)
00487 cum_derivatives[len-1]+=1 ;
00488 else
00489 {
00490 cum_derivatives[idx]+=(d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ;
00491 cum_derivatives[idx-1]+=(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ;
00492 }
00493 }
00494
00495 void CPlif::penalty_add_derivative_svm(DREAL p_value, DREAL *d_values)
00496 {
00497 ASSERT(use_svm>0);
00498 DREAL d_value=d_values[use_svm-1] ;
00499
00500 if (d_value<-1e+20)
00501 return;
00502
00503 switch (transform)
00504 {
00505 case T_LINEAR:
00506 break ;
00507 case T_LOG:
00508 d_value = log(d_value) ;
00509 break ;
00510 case T_LOG_PLUS1:
00511 d_value = log(d_value+1) ;
00512 break ;
00513 case T_LOG_PLUS3:
00514 d_value = log(d_value+3) ;
00515 break ;
00516 case T_LINEAR_PLUS3:
00517 d_value = d_value+3 ;
00518 break ;
00519 default:
00520 SG_ERROR( "unknown transform\n") ;
00521 break ;
00522 }
00523
00524 INT idx = 0 ;
00525 for (INT i=0; i<len; i++)
00526 if (limits[i]<=d_value)
00527 idx++ ;
00528 else
00529 break ;
00530
00531 if (idx==0)
00532 cum_derivatives[0]+=1 ;
00533 else if (idx==len)
00534 cum_derivatives[len-1]+=1 ;
00535 else
00536 {
00537 cum_derivatives[idx]+=(d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ;
00538 cum_derivatives[idx-1]+=(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ;
00539 }
00540 }
00541 void CPlif::get_used_svms(INT* num_svms, INT* svm_ids)
00542 {
00543 if (use_svm)
00544 {
00545 svm_ids[(*num_svms)] = use_svm;
00546 (*num_svms)++;
00547 }
00548 SG_PRINT("->use_svm:%i plif_id:%i name:%s trans_type:%s ",use_svm, get_id(), get_name(), get_transform_type());
00549 }
00550 bool CPlif::get_do_calc()
00551 {
00552 return do_calc;
00553 }
00554 void CPlif::set_do_calc(bool b)
00555 {
00556 do_calc = b;;
00557 }
00558
00559
00560