Plif.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Gunnar Raetsch
00008  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/config.h"
00012 
00013 #include <stdio.h>
00014 #include <string.h>
00015 
00016 #include "lib/io.h"
00017 
00018 #include "lib/matlab.h"
00019 
00020 #include "structure/Plif.h"
00021 
00022 //#define PLIF_DEBUG
00023 
00024 CPlif::CPlif(INT l)
00025 : CPlifBase()
00026 {
00027     limits=NULL;
00028     penalties=NULL;
00029     cum_derivatives=NULL;
00030     id=-1;
00031     transform=T_LINEAR;
00032     name=NULL;
00033     max_value=0;
00034     min_value=0;
00035     cache=NULL;
00036     use_svm=0;
00037     use_cache=false;
00038     len=0;
00039     do_calc = true;
00040     if (l>0)
00041         set_plif_length(l);
00042 }
00043 
00044 CPlif::~CPlif()
00045 {
00046     delete[] limits;
00047     delete[] penalties;
00048     delete[] name;
00049     delete[] cache;
00050     delete[] cum_derivatives;
00051 }
00052 
00053 bool CPlif::set_transform_type(const char *type_str)
00054 {
00055     delete[] cache ;
00056     cache=NULL ;
00057 
00058     if (strcmp(type_str, "linear")==0)
00059         transform = T_LINEAR ;
00060     else if (strcmp(type_str, "")==0)
00061         transform = T_LINEAR ;
00062     else if (strcmp(type_str, "log")==0)
00063         transform = T_LOG ;
00064     else if (strcmp(type_str, "log(+1)")==0)
00065         transform = T_LOG_PLUS1 ;
00066     else if (strcmp(type_str, "log(+3)")==0)
00067         transform = T_LOG_PLUS3 ;
00068     else if (strcmp(type_str, "(+3)")==0)
00069         transform = T_LINEAR_PLUS3 ;
00070     else
00071     {
00072         SG_ERROR( "unknown transform type (%s)\n", type_str) ;
00073         return false ;
00074     }
00075     return true ;
00076 }
00077 void CPlif::init_penalty_struct_cache()
00078 {
00079     if (!use_cache)
00080         return ;
00081     if (cache || use_svm)
00082         return ;
00083     if (max_value<=0)
00084         return ;
00085 
00086     DREAL* local_cache=new DREAL[ ((INT) max_value) + 2] ;
00087     
00088     if (local_cache)
00089     {
00090         for (INT i=0; i<=max_value; i++)
00091         {
00092             if (i<min_value)
00093                 local_cache[i] = -CMath::INFTY ;
00094             else
00095                 local_cache[i] = lookup_penalty(i, NULL) ;
00096         }
00097     }
00098     this->cache=local_cache ;
00099 }
00100 
00101     
00102 void CPlif::set_name(char *p_name) 
00103 {
00104     delete[] name ;
00105     name=new char[strlen(p_name)+1] ;
00106     strcpy(name,p_name) ;
00107 }
00108 
#ifdef HAVE_MATLAB
/** Error-path helper: free all plifs in @a PEN and the array itself.
 *
 * @param PEN array of P heap-allocated CPlif pointers (may be NULL)
 * @param P   number of entries in PEN
 */
static void free_plif_array_on_error(CPlif** PEN, INT P)
{
    if (PEN==NULL)
        return ;
    for (INT i=0; i<P; i++)
        delete PEN[i] ;
    delete[] PEN ;
}

/** Build P CPlif objects from a MATLAB cell array of penalty structs.
 *
 * Every cell must hold a struct with the fields
 *   id (1x1, 1-based), limits (1xN), penalties (1xN, same N as limits),
 *   transform (char), name (char), max_value (1x1), min_value (1x1),
 *   use_svm (1x1) and use_cache (1x1).
 * The plif described by a cell is stored at PEN[id-1]; ids must be unique
 * and lie in 1..P.
 *
 * On any validation failure an error is reported through SG_SERROR and, if
 * that returns, every plif allocated so far is freed and NULL is returned.
 *
 * @param mx_penalty_info MATLAB cell array holding the penalty structs
 * @param P number of penalty structs to read
 * @return array of P plifs (release with delete_penalty_struct) or NULL
 */
CPlif** read_penalty_struct_from_cell(const mxArray * mx_penalty_info, INT P)
{
    CPlif** PEN = new CPlif*[P] ;
    for (INT i=0; i<P; i++)
        PEN[i]=new CPlif() ;

    for (INT i=0; i<P; i++)
    {
        const mxArray* mx_elem = mxGetCell(mx_penalty_info, i) ;
        if (mx_elem==NULL || !mxIsStruct(mx_elem))
        {
            SG_SERROR("empty cell element\n") ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }
        const mxArray* mx_id_field = mxGetField(mx_elem, 0, "id") ;
        if (mx_id_field==NULL || !mxIsNumeric(mx_id_field) ||
            mxGetN(mx_id_field)!=1 || mxGetM(mx_id_field)!=1)
        {
            SG_SERROR( "missing id field\n") ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }
        const mxArray* mx_limits_field = mxGetField(mx_elem, 0, "limits") ;
        if (mx_limits_field==NULL || !mxIsNumeric(mx_limits_field) ||
            mxGetM(mx_limits_field)!=1)
        {
            SG_SERROR( "missing limits field\n") ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }
        INT len = mxGetN(mx_limits_field) ;

        const mxArray* mx_penalties_field = mxGetField(mx_elem, 0, "penalties") ;
        if (mx_penalties_field==NULL || !mxIsNumeric(mx_penalties_field) ||
            mxGetM(mx_penalties_field)!=1 || ((INT) mxGetN(mx_penalties_field))!=len)
        {
            SG_SERROR( "missing penalties field (%i)\n", i) ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }
        const mxArray* mx_transform_field = mxGetField(mx_elem, 0, "transform") ;
        if (mx_transform_field==NULL || !mxIsChar(mx_transform_field))
        {
            SG_SERROR( "missing transform field\n") ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }
        const mxArray* mx_name_field = mxGetField(mx_elem, 0, "name") ;
        if (mx_name_field==NULL || !mxIsChar(mx_name_field))
        {
            SG_SERROR( "missing name field\n") ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }
        const mxArray* mx_max_value_field = mxGetField(mx_elem, 0, "max_value") ;
        if (mx_max_value_field==NULL || !mxIsNumeric(mx_max_value_field) ||
            mxGetM(mx_max_value_field)!=1 || mxGetN(mx_max_value_field)!=1)
        {
            SG_SERROR( "missing max_value field\n") ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }
        const mxArray* mx_min_value_field = mxGetField(mx_elem, 0, "min_value") ;
        if (mx_min_value_field==NULL || !mxIsNumeric(mx_min_value_field) ||
            mxGetM(mx_min_value_field)!=1 || mxGetN(mx_min_value_field)!=1)
        {
            SG_SERROR( "missing min_value field\n") ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }
        const mxArray* mx_use_svm_field = mxGetField(mx_elem, 0, "use_svm") ;
        if (mx_use_svm_field==NULL || !mxIsNumeric(mx_use_svm_field) ||
            mxGetM(mx_use_svm_field)!=1 || mxGetN(mx_use_svm_field)!=1)
        {
            SG_SERROR( "missing use_svm field\n") ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }
        INT use_svm = (INT) mxGetScalar(mx_use_svm_field) ;

        const mxArray* mx_use_cache_field = mxGetField(mx_elem, 0, "use_cache") ;
        if (mx_use_cache_field==NULL || !mxIsNumeric(mx_use_cache_field) ||
            mxGetM(mx_use_cache_field)!=1 || mxGetN(mx_use_cache_field)!=1)
        {
            SG_SERROR( "missing use_cache field\n") ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }
        INT use_cache = (INT) mxGetScalar(mx_use_cache_field) ;

        // MATLAB ids are 1-based; validate the id itself (not the loop
        // counter) before it is used to index PEN
        INT id = (INT) mxGetScalar(mx_id_field)-1 ;
        if (id<0 || id>P-1)
        {
            SG_SERROR( "id out of range\n") ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }
        if (PEN[id]->get_id()!=-1)
        {
            SG_SERROR( "penalty id already used\n") ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }

        INT max_value = (INT) mxGetScalar(mx_max_value_field) ;
        if (max_value<-1024*1024*100 || max_value>1024*1024*100)
        {
            SG_SERROR( "max_value out of range\n") ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }
        PEN[id]->set_max_value(max_value) ;

        INT min_value = (INT) mxGetScalar(mx_min_value_field) ;
        if (min_value<-1024*1024*100 || min_value>1024*1024*100)
        {
            SG_SERROR( "min_value out of range\n") ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }
        PEN[id]->set_min_value(min_value) ;

        PEN[id]->set_id(id) ;
        PEN[id]->set_use_svm(use_svm) ;
        PEN[id]->set_use_cache(use_cache) ;

        double * limits = mxGetPr(mx_limits_field) ;
        double * penalties = mxGetPr(mx_penalties_field) ;
        PEN[id]->set_plif(len, limits, penalties) ;

        char *transform_str = mxArrayToString(mx_transform_field) ;
        char *name_str = mxArrayToString(mx_name_field) ;
        if (transform_str==NULL || name_str==NULL)
        {
            SG_SERROR( "could not convert transform/name field to a string\n") ;
            if (transform_str!=NULL)
                mxFree(transform_str) ;
            if (name_str!=NULL)
                mxFree(name_str) ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }

        if (!PEN[id]->set_transform_type(transform_str))
        {
            SG_SERROR( "transform type not recognized ('%s')\n", transform_str) ;
            mxFree(transform_str) ;
            mxFree(name_str) ;
            free_plif_array_on_error(PEN, P) ;
            return NULL ;
        }

        PEN[id]->set_name(name_str) ;
        PEN[id]->init_penalty_struct_cache() ;

        mxFree(transform_str) ;
        mxFree(name_str) ;
    }
    return PEN ;
}
#endif
00275 
00276 void delete_penalty_struct(CPlif** PEN, INT P) 
00277 {
00278     for (INT i=0; i<P; i++)
00279         delete PEN[i] ;
00280     delete[] PEN ;
00281 }
00282 
00283 DREAL CPlif::lookup_penalty_svm(DREAL p_value, DREAL *d_values) const
00284 {   
00285     ASSERT(use_svm>0);
00286     DREAL d_value=d_values[use_svm-1] ;
00287 #ifdef PLIF_DEBUG
00288     SG_PRINT("%s.lookup_penalty_svm(%f)\n", get_name(), d_value) ;
00289 #endif
00290 
00291     if (!do_calc)
00292         return d_value;
00293     switch (transform)
00294     {
00295     case T_LINEAR:
00296         break ;
00297     case T_LOG:
00298         d_value = log(d_value) ;
00299         break ;
00300     case T_LOG_PLUS1:
00301         d_value = log(d_value+1) ;
00302         break ;
00303     case T_LOG_PLUS3:
00304         d_value = log(d_value+3) ;
00305         break ;
00306     case T_LINEAR_PLUS3:
00307         d_value = d_value+3 ;
00308         break ;
00309     default:
00310         SG_ERROR("unknown transform\n");
00311         break ;
00312     }
00313     
00314     INT idx = 0 ;
00315     DREAL ret ;
00316     for (INT i=0; i<len; i++)
00317         if (limits[i]<=d_value)
00318             idx++ ;
00319         else
00320             break ; // assume it is monotonically increasing
00321      
00322 #ifdef PLIF_DEBUG
00323     SG_PRINT("  -> idx = %i ", idx) ;
00324 #endif
00325     
00326     if (idx==0)
00327         ret=penalties[0] ;
00328     else if (idx==len)
00329         ret=penalties[len-1] ;
00330     else
00331     {
00332         ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]*
00333                (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ;  
00334 #ifdef PLIF_DEBUG
00335         SG_PRINT("  -> (%1.3f*%1.3f, %1.3f*%1.3f)", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) ;
00336 #endif
00337     }
00338 #ifdef PLIF_DEBUG
00339         SG_PRINT("  -> ret=%1.3f\n", ret) ;
00340 #endif
00341     
00342     return ret ;
00343 }
00344 
00345 DREAL CPlif::lookup_penalty(INT p_value, DREAL* svm_values) const
00346 {
00347     if (use_svm)
00348         return lookup_penalty_svm(p_value, svm_values) ;
00349 
00350     if ((p_value<min_value) || (p_value>max_value))
00351         return -CMath::INFTY ;
00352     if (!do_calc)
00353         return p_value;
00354     if (cache!=NULL && (p_value>=0) && (p_value<=max_value))
00355     {
00356         DREAL ret=cache[p_value] ;
00357         return ret ;
00358     }
00359     return lookup_penalty((DREAL) p_value, svm_values) ;
00360 }
00361 
00362 DREAL CPlif::lookup_penalty(DREAL p_value, DREAL* svm_values) const
00363 {   
00364     if (use_svm)
00365         return lookup_penalty_svm(p_value, svm_values) ;
00366 
00367 #ifdef PLIF_DEBUG
00368     SG_PRINT("%s.lookup_penalty(%f)\n", get_name(), p_value) ;
00369 #endif
00370 
00371 
00372     if ((p_value<min_value) || (p_value>max_value))
00373         return -CMath::INFTY ;
00374 
00375     if (!do_calc)
00376         return p_value;
00377 
00378     DREAL d_value = (DREAL) p_value ;
00379     switch (transform)
00380     {
00381     case T_LINEAR:
00382         break ;
00383     case T_LOG:
00384         d_value = log(d_value) ;
00385         break ;
00386     case T_LOG_PLUS1:
00387         d_value = log(d_value+1) ;
00388         break ;
00389     case T_LOG_PLUS3:
00390         d_value = log(d_value+3) ;
00391         break ;
00392     case T_LINEAR_PLUS3:
00393         d_value = d_value+3 ;
00394         break ;
00395     default:
00396         SG_ERROR( "unknown transform\n") ;
00397         break ;
00398     }
00399 
00400 #ifdef PLIF_DEBUG
00401     SG_PRINT("  -> value = %1.4f ", d_value) ;
00402 #endif
00403 
00404     INT idx = 0 ;
00405     DREAL ret ;
00406     for (INT i=0; i<len; i++)
00407         if (limits[i]<=d_value)
00408             idx++ ;
00409         else
00410             break ; // assume it is monotonically increasing
00411     
00412 #ifdef PLIF_DEBUG
00413     SG_PRINT("  -> idx = %i ", idx) ;
00414 #endif
00415     
00416     if (idx==0)
00417         ret=penalties[0] ;
00418     else if (idx==len)
00419         ret=penalties[len-1] ;
00420     else
00421     {
00422         ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]*
00423                (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ;  
00424 #ifdef PLIF_DEBUG
00425         SG_PRINT("  -> (%1.3f*%1.3f, %1.3f*%1.3f) ", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) ;
00426 #endif
00427     }
00428     //if (p_value>=30 && p_value<150)
00429     //fprintf(stderr, "%s %i(%i) -> %1.2f\n", PEN->name, p_value, idx, ret) ;
00430 #ifdef PLIF_DEBUG
00431     SG_PRINT("  -> ret=%1.3f\n", ret) ;
00432 #endif
00433     
00434     return ret ;
00435 }
00436 
00437 void CPlif::penalty_clear_derivative() 
00438 {
00439     for (INT i=0; i<len; i++)
00440         cum_derivatives[i]=0.0 ;
00441 }
00442 
00443 void CPlif::penalty_add_derivative(DREAL p_value, DREAL* svm_values) 
00444 {
00445     if (use_svm)
00446     {
00447         penalty_add_derivative_svm(p_value, svm_values) ;
00448         return ;
00449     }
00450     
00451     if ((p_value<min_value) || (p_value>max_value))
00452     {
00453         return ;
00454     }
00455     DREAL d_value = (DREAL) p_value ;
00456     switch (transform)
00457     {
00458     case T_LINEAR:
00459         break ;
00460     case T_LOG:
00461         d_value = log(d_value) ;
00462         break ;
00463     case T_LOG_PLUS1:
00464         d_value = log(d_value+1) ;
00465         break ;
00466     case T_LOG_PLUS3:
00467         d_value = log(d_value+3) ;
00468         break ;
00469     case T_LINEAR_PLUS3:
00470         d_value = d_value+3 ;
00471         break ;
00472     default:
00473         SG_ERROR( "unknown transform\n") ;
00474         break ;
00475     }
00476 
00477     INT idx = 0 ;
00478     for (INT i=0; i<len; i++)
00479         if (limits[i]<=d_value)
00480             idx++ ;
00481         else
00482             break ; // assume it is monotonically increasing
00483     
00484     if (idx==0)
00485         cum_derivatives[0]+=1 ;
00486     else if (idx==len)
00487         cum_derivatives[len-1]+=1 ;
00488     else
00489     {
00490         cum_derivatives[idx]+=(d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ;
00491         cum_derivatives[idx-1]+=(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ;
00492     }
00493 }
00494 
00495 void CPlif::penalty_add_derivative_svm(DREAL p_value, DREAL *d_values) 
00496 {   
00497     ASSERT(use_svm>0);
00498     DREAL d_value=d_values[use_svm-1] ;
00499 
00500     if (d_value<-1e+20)
00501         return;
00502     
00503     switch (transform)
00504     {
00505     case T_LINEAR:
00506         break ;
00507     case T_LOG:
00508         d_value = log(d_value) ;
00509         break ;
00510     case T_LOG_PLUS1:
00511         d_value = log(d_value+1) ;
00512         break ;
00513     case T_LOG_PLUS3:
00514         d_value = log(d_value+3) ;
00515         break ;
00516     case T_LINEAR_PLUS3:
00517         d_value = d_value+3 ;
00518         break ;
00519     default:
00520         SG_ERROR( "unknown transform\n") ;
00521         break ;
00522     }
00523     
00524     INT idx = 0 ;
00525     for (INT i=0; i<len; i++)
00526         if (limits[i]<=d_value)
00527             idx++ ;
00528         else
00529             break ; // assume it is monotonically increasing
00530     
00531     if (idx==0)
00532         cum_derivatives[0]+=1 ;
00533     else if (idx==len)
00534         cum_derivatives[len-1]+=1 ;
00535     else
00536     {
00537         cum_derivatives[idx]+=(d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ;
00538         cum_derivatives[idx-1]+=(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ;
00539     }
00540 }
00541 void CPlif::get_used_svms(INT* num_svms, INT* svm_ids)
00542 {
00543     if (use_svm)
00544     {
00545         svm_ids[(*num_svms)] = use_svm;
00546         (*num_svms)++;
00547     }
00548     SG_PRINT("->use_svm:%i plif_id:%i name:%s trans_type:%s  ",use_svm, get_id(), get_name(), get_transform_type());
00549 }
00550 bool CPlif::get_do_calc()
00551 {
00552     return do_calc;
00553 }
00554 void CPlif::set_do_calc(bool b)
00555 {
00556     do_calc = b;;
00557 }
00558 
00559 
00560 

SHOGUN Machine Learning Toolbox - Documentation