Plif.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Gunnar Raetsch
00008  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/config.h"
00012 
00013 #include <stdio.h>
00014 #include <string.h>
00015 
00016 #include "lib/io.h"
00017 
00018 #include "lib/matlab.h"
00019 
00020 #include "structure/Plif.h"
00021 
00022 //#define PLIF_DEBUG
00023 
00024 CPlif::CPlif(int32_t l)
00025 : CPlifBase()
00026 {
00027     limits=NULL;
00028     penalties=NULL;
00029     cum_derivatives=NULL;
00030     id=-1;
00031     transform=T_LINEAR;
00032     name=NULL;
00033     max_value=0;
00034     min_value=0;
00035     cache=NULL;
00036     use_svm=0;
00037     use_cache=false;
00038     len=0;
00039     do_calc = true;
00040     if (l>0)
00041         set_plif_length(l);
00042 }
00043 
00044 CPlif::~CPlif()
00045 {
00046     delete[] limits;
00047     delete[] penalties;
00048     delete[] name;
00049     delete[] cache;
00050     delete[] cum_derivatives;
00051 }
00052 
00053 bool CPlif::set_transform_type(const char *type_str)
00054 {
00055     delete[] cache ;
00056     cache=NULL ;
00057 
00058     if (strcmp(type_str, "linear")==0)
00059         transform = T_LINEAR ;
00060     else if (strcmp(type_str, "")==0)
00061         transform = T_LINEAR ;
00062     else if (strcmp(type_str, "log")==0)
00063         transform = T_LOG ;
00064     else if (strcmp(type_str, "log(+1)")==0)
00065         transform = T_LOG_PLUS1 ;
00066     else if (strcmp(type_str, "log(+3)")==0)
00067         transform = T_LOG_PLUS3 ;
00068     else if (strcmp(type_str, "(+3)")==0)
00069         transform = T_LINEAR_PLUS3 ;
00070     else
00071     {
00072         SG_ERROR( "unknown transform type (%s)\n", type_str) ;
00073         return false ;
00074     }
00075     return true ;
00076 }
00077 
00078 void CPlif::init_penalty_struct_cache()
00079 {
00080     if (!use_cache)
00081         return ;
00082     if (cache || use_svm)
00083         return ;
00084     if (max_value<=0)
00085         return ;
00086 
00087     float64_t* local_cache=new float64_t[ ((int32_t) max_value) + 2] ;
00088     
00089     if (local_cache)
00090     {
00091         for (int32_t i=0; i<=max_value; i++)
00092         {
00093             if (i<min_value)
00094                 local_cache[i] = -CMath::INFTY ;
00095             else
00096                 local_cache[i] = lookup_penalty(i, NULL) ;
00097         }
00098     }
00099     this->cache=local_cache ;
00100 }
00101 
00102 void CPlif::set_name(char *p_name)
00103 {
00104     delete[] name ;
00105     name=new char[strlen(p_name)+1] ;
00106     strcpy(name,p_name) ;
00107 }
00108 
00109 #ifdef HAVE_MATLAB
00110 CPlif** read_penalty_struct_from_cell(
00111     const mxArray * mx_penalty_info, int32_t P)
00112 {
00113     //P = mxGetN(mx_penalty_info) ;
00114     //fprintf(stderr, "p=%i size=%i\n", P, P*sizeof(CPlif)) ;
00115     
00116     CPlif** PEN = new CPlif*[P] ;
00117     for (int32_t i=0; i<P; i++)
00118         PEN[i]=new CPlif() ;
00119     
00120     for (int32_t i=0; i<P; i++)
00121     {
00122         //fprintf(stderr, "i=%i/%i\n", i, P) ;
00123         
00124         const mxArray* mx_elem = mxGetCell(mx_penalty_info, i) ;
00125         if (mx_elem==NULL || !mxIsStruct(mx_elem))
00126         {
00127             SG_SERROR("empty cell element\n") ;
00128             delete[] PEN ;
00129             return NULL ;
00130         }
00131         const mxArray* mx_id_field = mxGetField(mx_elem, 0, "id") ;
00132         if (mx_id_field==NULL || !mxIsNumeric(mx_id_field) || 
00133             mxGetN(mx_id_field)!=1 || mxGetM(mx_id_field)!=1)
00134         {
00135             SG_SERROR( "missing id field\n") ;
00136             delete[] PEN;
00137             return NULL ;
00138         }
00139         const mxArray* mx_limits_field = mxGetField(mx_elem, 0, "limits") ;
00140         if (mx_limits_field==NULL || !mxIsNumeric(mx_limits_field) ||
00141             mxGetM(mx_limits_field)!=1)
00142         {
00143             SG_SERROR( "missing limits field\n") ;
00144             delete[] PEN ;
00145             return NULL ;
00146         }
00147         int32_t len = mxGetN(mx_limits_field) ;
00148         
00149         const mxArray* mx_penalties_field = mxGetField(mx_elem, 0, "penalties") ;
00150         if (mx_penalties_field==NULL || !mxIsNumeric(mx_penalties_field) ||
00151             mxGetM(mx_penalties_field)!=1 || ((int32_t) mxGetN(mx_penalties_field))!=len)
00152         {
00153             SG_SERROR( "missing penalties field (%i)\n", i) ;
00154             delete[] PEN ;
00155             return NULL ;
00156         }
00157         const mxArray* mx_transform_field = mxGetField(mx_elem, 0, "transform") ;
00158         if (mx_transform_field==NULL || !mxIsChar(mx_transform_field))
00159         {
00160             SG_SERROR( "missing transform field\n") ;
00161             delete[] PEN;
00162             return NULL ;
00163         }
00164         const mxArray* mx_name_field = mxGetField(mx_elem, 0, "name") ;
00165         if (mx_name_field==NULL || !mxIsChar(mx_name_field))
00166         {
00167             SG_SERROR( "missing name field\n") ;
00168             delete[] PEN;
00169             return NULL ;
00170         }
00171         const mxArray* mx_max_value_field = mxGetField(mx_elem, 0, "max_value") ;
00172         if (mx_max_value_field==NULL || !mxIsNumeric(mx_max_value_field) ||
00173             mxGetM(mx_max_value_field)!=1 || mxGetN(mx_max_value_field)!=1)
00174         {
00175             SG_SERROR( "missing max_value field\n") ;
00176             delete[] PEN;
00177             return NULL ;
00178         }
00179         const mxArray* mx_min_value_field = mxGetField(mx_elem, 0, "min_value") ;
00180         if (mx_min_value_field==NULL || !mxIsNumeric(mx_min_value_field) ||
00181             mxGetM(mx_min_value_field)!=1 || mxGetN(mx_min_value_field)!=1)
00182         {
00183             SG_SERROR( "missing min_value field\n") ;
00184             delete[] PEN;
00185             return NULL ;
00186         }
00187         const mxArray* mx_use_svm_field = mxGetField(mx_elem, 0, "use_svm") ;
00188         if (mx_use_svm_field==NULL || !mxIsNumeric(mx_use_svm_field) ||
00189             mxGetM(mx_use_svm_field)!=1 || mxGetN(mx_use_svm_field)!=1)
00190         {
00191             SG_SERROR( "missing use_svm field\n") ;
00192             delete[] PEN;
00193             return NULL ;
00194         }
00195         int32_t use_svm = (int32_t) mxGetScalar(mx_use_svm_field) ;
00196 
00197         const mxArray* mx_use_cache_field = mxGetField(mx_elem, 0, "use_cache") ;
00198         if (mx_use_cache_field==NULL || !mxIsNumeric(mx_use_cache_field) ||
00199             mxGetM(mx_use_cache_field)!=1 || mxGetN(mx_use_cache_field)!=1)
00200         {
00201             SG_SERROR( "missing use_cache field\n") ;
00202             delete[] PEN;
00203             return NULL ;
00204         }
00205         int32_t use_cache = (int32_t) mxGetScalar(mx_use_cache_field) ;
00206 
00207         int32_t id = (int32_t) mxGetScalar(mx_id_field)-1 ;
00208         if (i<0 || i>P-1)
00209         {
00210             SG_SERROR( "id out of range\n") ;
00211             delete[] PEN;
00212             return NULL ;
00213         }
00214         int32_t max_value = (int32_t) mxGetScalar(mx_max_value_field) ;
00215         if (max_value<-1024*1024*100 || max_value>1024*1024*100)
00216         {
00217             SG_SERROR( "max_value out of range\n") ;
00218             delete[] PEN;
00219             return NULL ;
00220         }
00221         PEN[id]->set_max_value(max_value) ;
00222 
00223         int32_t min_value = (int32_t) mxGetScalar(mx_min_value_field) ;
00224         if (min_value<-1024*1024*100 || min_value>1024*1024*100)
00225         {
00226             SG_SERROR( "min_value out of range\n") ;
00227             delete[] PEN;
00228             return NULL ;
00229         }
00230         PEN[id]->set_min_value(min_value) ;
00231         //SG_PRINT("id: %i, min_value: %i,  max_value: %i\n",id,min_value, max_value);
00232         
00233         if (PEN[id]->get_id()!=-1)
00234         {
00235             SG_SERROR( "penalty id already used\n") ;
00236             delete[] PEN;
00237             return NULL ;
00238         }
00239         PEN[id]->set_id(id) ;
00240         
00241         PEN[id]->set_use_svm(use_svm) ;
00242         PEN[id]->set_use_cache(use_cache) ;
00243 
00244         double * limits = mxGetPr(mx_limits_field) ;
00245         double * penalties = mxGetPr(mx_penalties_field) ;
00246         PEN[id]->set_plif(len, limits, penalties) ;
00247         
00248         char *transform_str = mxArrayToString(mx_transform_field) ;             
00249         char *name_str = mxArrayToString(mx_name_field) ;               
00250 
00251         if (!PEN[id]->set_transform_type(transform_str))
00252         {
00253             SG_SERROR( "transform type not recognized ('%s')\n", transform_str) ;
00254             delete[] PEN;
00255             mxFree(transform_str) ;
00256             return NULL ;
00257         }
00258 
00259         PEN[id]->set_name(name_str) ;
00260         PEN[id]->init_penalty_struct_cache() ;
00261 
00262 /*      if (PEN->cache)
00263 /           SG_SDEBUG( "penalty_info: name=%s id=%i points=%i min_value=%i max_value=%i transform='%s' (cache initialized)\n", PEN[id]->name,
00264                     PEN[id]->id, PEN[id]->len, PEN[id]->min_value, PEN[id]->max_value, transform_str) ;
00265         else
00266             SG_SDEBUG( "penalty_info: name=%s id=%i points=%i min_value=%i max_value=%i transform='%s'\n", PEN[id]->name,
00267                     PEN[id]->id, PEN[id]->len, PEN[id]->min_value, PEN[id]->max_value, transform_str) ;
00268 */
00269         
00270         mxFree(transform_str) ;
00271         mxFree(name_str) ;
00272     }
00273     return PEN ;
00274 }
00275 #endif
00276 
00277 void delete_penalty_struct(CPlif** PEN, int32_t P) 
00278 {
00279     for (int32_t i=0; i<P; i++)
00280         delete PEN[i] ;
00281     delete[] PEN ;
00282 }
00283 
00284 float64_t CPlif::lookup_penalty_svm(
00285     float64_t p_value, float64_t *d_values) const
00286 {
00287     ASSERT(use_svm>0);
00288     float64_t d_value=d_values[use_svm-1] ;
00289 #ifdef PLIF_DEBUG
00290     SG_PRINT("%s.lookup_penalty_svm(%f)\n", get_name(), d_value) ;
00291 #endif
00292 
00293     if (!do_calc)
00294         return d_value;
00295     switch (transform)
00296     {
00297     case T_LINEAR:
00298         break ;
00299     case T_LOG:
00300         d_value = log(d_value) ;
00301         break ;
00302     case T_LOG_PLUS1:
00303         d_value = log(d_value+1) ;
00304         break ;
00305     case T_LOG_PLUS3:
00306         d_value = log(d_value+3) ;
00307         break ;
00308     case T_LINEAR_PLUS3:
00309         d_value = d_value+3 ;
00310         break ;
00311     default:
00312         SG_ERROR("unknown transform\n");
00313         break ;
00314     }
00315     
00316     int32_t idx = 0 ;
00317     float64_t ret ;
00318     for (int32_t i=0; i<len; i++)
00319         if (limits[i]<=d_value)
00320             idx++ ;
00321         else
00322             break ; // assume it is monotonically increasing
00323      
00324 #ifdef PLIF_DEBUG
00325     SG_PRINT("  -> idx = %i ", idx) ;
00326 #endif
00327     
00328     if (idx==0)
00329         ret=penalties[0] ;
00330     else if (idx==len)
00331         ret=penalties[len-1] ;
00332     else
00333     {
00334         ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]*
00335                (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ;  
00336 #ifdef PLIF_DEBUG
00337         SG_PRINT("  -> (%1.3f*%1.3f, %1.3f*%1.3f)", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) ;
00338 #endif
00339     }
00340 #ifdef PLIF_DEBUG
00341         SG_PRINT("  -> ret=%1.3f\n", ret) ;
00342 #endif
00343     
00344     return ret ;
00345 }
00346 
00347 float64_t CPlif::lookup_penalty(int32_t p_value, float64_t* svm_values) const
00348 {
00349     if (use_svm)
00350         return lookup_penalty_svm(p_value, svm_values) ;
00351 
00352     if ((p_value<min_value) || (p_value>max_value))
00353         return -CMath::INFTY ;
00354     if (!do_calc)
00355         return p_value;
00356     if (cache!=NULL && (p_value>=0) && (p_value<=max_value))
00357     {
00358         float64_t ret=cache[p_value] ;
00359         return ret ;
00360     }
00361     return lookup_penalty((float64_t) p_value, svm_values) ;
00362 }
00363 
00364 float64_t CPlif::lookup_penalty(float64_t p_value, float64_t* svm_values) const
00365 {
00366     if (use_svm)
00367         return lookup_penalty_svm(p_value, svm_values) ;
00368 
00369 #ifdef PLIF_DEBUG
00370     SG_PRINT("%s.lookup_penalty(%f)\n", get_name(), p_value) ;
00371 #endif
00372 
00373 
00374     if ((p_value<min_value) || (p_value>max_value))
00375         return -CMath::INFTY ;
00376 
00377     if (!do_calc)
00378         return p_value;
00379 
00380     float64_t d_value = (float64_t) p_value ;
00381     switch (transform)
00382     {
00383     case T_LINEAR:
00384         break ;
00385     case T_LOG:
00386         d_value = log(d_value) ;
00387         break ;
00388     case T_LOG_PLUS1:
00389         d_value = log(d_value+1) ;
00390         break ;
00391     case T_LOG_PLUS3:
00392         d_value = log(d_value+3) ;
00393         break ;
00394     case T_LINEAR_PLUS3:
00395         d_value = d_value+3 ;
00396         break ;
00397     default:
00398         SG_ERROR( "unknown transform\n") ;
00399         break ;
00400     }
00401 
00402 #ifdef PLIF_DEBUG
00403     SG_PRINT("  -> value = %1.4f ", d_value) ;
00404 #endif
00405 
00406     int32_t idx = 0 ;
00407     float64_t ret ;
00408     for (int32_t i=0; i<len; i++)
00409         if (limits[i]<=d_value)
00410             idx++ ;
00411         else
00412             break ; // assume it is monotonically increasing
00413     
00414 #ifdef PLIF_DEBUG
00415     SG_PRINT("  -> idx = %i ", idx) ;
00416 #endif
00417     
00418     if (idx==0)
00419         ret=penalties[0] ;
00420     else if (idx==len)
00421         ret=penalties[len-1] ;
00422     else
00423     {
00424         ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]*
00425                (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ;  
00426 #ifdef PLIF_DEBUG
00427         SG_PRINT("  -> (%1.3f*%1.3f, %1.3f*%1.3f) ", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) ;
00428 #endif
00429     }
00430     //if (p_value>=30 && p_value<150)
00431     //fprintf(stderr, "%s %i(%i) -> %1.2f\n", PEN->name, p_value, idx, ret) ;
00432 #ifdef PLIF_DEBUG
00433     SG_PRINT("  -> ret=%1.3f\n", ret) ;
00434 #endif
00435     
00436     return ret ;
00437 }
00438 
00439 void CPlif::penalty_clear_derivative() 
00440 {
00441     for (int32_t i=0; i<len; i++)
00442         cum_derivatives[i]=0.0 ;
00443 }
00444 
00445 void CPlif::penalty_add_derivative(float64_t p_value, float64_t* svm_values)
00446 {
00447     if (use_svm)
00448     {
00449         penalty_add_derivative_svm(p_value, svm_values) ;
00450         return ;
00451     }
00452     
00453     if ((p_value<min_value) || (p_value>max_value))
00454     {
00455         return ;
00456     }
00457     float64_t d_value = (float64_t) p_value ;
00458     switch (transform)
00459     {
00460     case T_LINEAR:
00461         break ;
00462     case T_LOG:
00463         d_value = log(d_value) ;
00464         break ;
00465     case T_LOG_PLUS1:
00466         d_value = log(d_value+1) ;
00467         break ;
00468     case T_LOG_PLUS3:
00469         d_value = log(d_value+3) ;
00470         break ;
00471     case T_LINEAR_PLUS3:
00472         d_value = d_value+3 ;
00473         break ;
00474     default:
00475         SG_ERROR( "unknown transform\n") ;
00476         break ;
00477     }
00478 
00479     int32_t idx = 0 ;
00480     for (int32_t i=0; i<len; i++)
00481         if (limits[i]<=d_value)
00482             idx++ ;
00483         else
00484             break ; // assume it is monotonically increasing
00485     
00486     if (idx==0)
00487         cum_derivatives[0]+=1 ;
00488     else if (idx==len)
00489         cum_derivatives[len-1]+=1 ;
00490     else
00491     {
00492         cum_derivatives[idx]+=(d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ;
00493         cum_derivatives[idx-1]+=(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ;
00494     }
00495 }
00496 
00497 void CPlif::penalty_add_derivative_svm(float64_t p_value, float64_t *d_values)
00498 {
00499     ASSERT(use_svm>0);
00500     float64_t d_value=d_values[use_svm-1] ;
00501 
00502     if (d_value<-1e+20)
00503         return;
00504     
00505     switch (transform)
00506     {
00507     case T_LINEAR:
00508         break ;
00509     case T_LOG:
00510         d_value = log(d_value) ;
00511         break ;
00512     case T_LOG_PLUS1:
00513         d_value = log(d_value+1) ;
00514         break ;
00515     case T_LOG_PLUS3:
00516         d_value = log(d_value+3) ;
00517         break ;
00518     case T_LINEAR_PLUS3:
00519         d_value = d_value+3 ;
00520         break ;
00521     default:
00522         SG_ERROR( "unknown transform\n") ;
00523         break ;
00524     }
00525     
00526     int32_t idx = 0 ;
00527     for (int32_t i=0; i<len; i++)
00528         if (limits[i]<=d_value)
00529             idx++ ;
00530         else
00531             break ; // assume it is monotonically increasing
00532     
00533     if (idx==0)
00534         cum_derivatives[0]+=1 ;
00535     else if (idx==len)
00536         cum_derivatives[len-1]+=1 ;
00537     else
00538     {
00539         cum_derivatives[idx]+=(d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ;
00540         cum_derivatives[idx-1]+=(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ;
00541     }
00542 }
00543 
00544 void CPlif::get_used_svms(int32_t* num_svms, int32_t* svm_ids)
00545 {
00546     if (use_svm)
00547     {
00548         svm_ids[(*num_svms)] = use_svm;
00549         (*num_svms)++;
00550     }
00551     SG_PRINT("->use_svm:%i plif_id:%i name:%s trans_type:%s  ",use_svm, get_id(), get_name(), get_transform_type());
00552 }
00553 
00554 bool CPlif::get_do_calc()
00555 {
00556     return do_calc;
00557 }
00558 
00559 void CPlif::set_do_calc(bool b)
00560 {
00561     do_calc = b;;
00562 }

SHOGUN Machine Learning Toolbox - Documentation