MPDSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Soeren Sonnenburg
00008  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "classifier/svm/MPDSVM.h"
00012 #include "lib/io.h"
00013 #include "lib/common.h"
00014 #include "lib/Mathematics.h"
00015 
00016 CMPDSVM::CMPDSVM()
00017 : CSVM()
00018 {
00019 }
00020 
00021 CMPDSVM::CMPDSVM(DREAL C, CKernel* k, CLabels* lab)
00022 : CSVM(C, k, lab)
00023 {
00024 }
00025 
00026 CMPDSVM::~CMPDSVM()
00027 {
00028 }
00029 
00030 bool CMPDSVM::train()
00031 {
00032     ASSERT(labels);
00033     ASSERT(kernel && kernel->has_features());
00034     //const DREAL nu=0.32;
00035     const DREAL alpha_eps=1e-12;
00036     const DREAL eps=get_epsilon();
00037     const long int maxiter = 1L<<30;
00038     //const bool nustop=false;
00039     //const int k=2;
00040     const int n=labels->get_num_labels();
00041     ASSERT(n>0);
00042     //const DREAL d = 1.0/n/nu; //NUSVC
00043     const DREAL d = get_C1(); //CSVC
00044     const DREAL primaleps=eps;
00045     const DREAL dualeps=eps*n; //heuristic
00046     long int niter=0;
00047 
00048     kernel_cache = new CCache<KERNELCACHE_ELEM>(kernel->get_cache_size(), n, n);
00049     DREAL* alphas=new DREAL[n];
00050     DREAL* dalphas=new DREAL[n];
00051     //DREAL* hessres=new DREAL[2*n];
00052     DREAL* hessres=new DREAL[n];
00053     //DREAL* F=new DREAL[2*n];
00054     DREAL* F=new DREAL[n];
00055 
00056     //DREAL hessest[2]={0,0};
00057     //DREAL hstep[2];
00058     //DREAL etas[2]={0,0};
00059     //DREAL detas[2]={0,1}; //NUSVC
00060     DREAL etas=0;
00061     DREAL detas=0;   //CSVC
00062     DREAL hessest=0;
00063     DREAL hstep;
00064 
00065     const DREAL stopfac = 1;
00066 
00067     bool primalcool;
00068     bool dualcool;
00069 
00070     //if (nustop)
00071     //etas[1] = 1;
00072 
00073     for (int i=0; i<n; i++)
00074     {
00075         alphas[i]=0;
00076         F[i]=labels->get_label(i);
00077         //F[i+n]=-1;
00078         hessres[i]=labels->get_label(i);
00079         //hessres[i+n]=-1;
00080         //dalphas[i]=F[i+n]*etas[1]; //NUSVC
00081         dalphas[i]=-1; //CSVC
00082     }
00083 
00084     // go ...
00085     while (niter++ < maxiter)
00086     {
00087         int maxpidx=-1;
00088         DREAL maxpviol = -1;
00089         //DREAL maxdviol = CMath::abs(detas[0]);
00090         DREAL maxdviol = CMath::abs(detas);
00091         bool free_alpha=false;
00092 
00093         //if (CMath::abs(detas[1])> maxdviol)
00094         //maxdviol=CMath::abs(detas[1]);
00095 
00096         // compute kkt violations with correct sign ...
00097         for (int i=0; i<n; i++)
00098         {
00099             DREAL v=CMath::abs(dalphas[i]);
00100 
00101             if (alphas[i] > 0 && alphas[i] < d)
00102                 free_alpha=true;
00103 
00104             if ( (dalphas[i]==0) ||
00105                     (alphas[i]==0 && dalphas[i] >0) ||
00106                     (alphas[i]==d && dalphas[i] <0)
00107                )
00108                 v=0;
00109 
00110             if (v > maxpviol)
00111             {
00112                 maxpviol=v;
00113                 maxpidx=i;
00114             } // if we cannot improve on maxpviol, we can still improve by choosing a cached element
00115             else if (v == maxpviol) 
00116             {
00117                 if (kernel_cache->is_cached(i))
00118                     maxpidx=i;
00119             }
00120         }
00121 
00122         if (maxpidx<0 || maxdviol<0)
00123             SG_ERROR( "no violation no convergence, should not happen!\n");
00124 
00125         // ... and evaluate stopping conditions
00126         //if (nustop)
00127         //stopfac = CMath::max(etas[1], 1e-10);    
00128         //else
00129         //stopfac = 1;
00130 
00131         if (niter%10000 == 0)
00132         {
00133             DREAL obj=0;
00134 
00135             for (int i=0; i<n; i++)
00136             {
00137                 obj-=alphas[i];
00138                 for (int j=0; j<n; j++)
00139                     obj+=0.5*labels->get_label(i)*labels->get_label(j)*alphas[i]*alphas[j]*kernel->kernel(i,j);
00140             }
00141 
00142             SG_DEBUG( "obj:%f pviol:%f dviol:%f maxpidx:%d iter:%d\n", obj, maxpviol, maxdviol, maxpidx, niter);
00143         }
00144 
00145         //for (int i=0; i<n; i++)
00146         //  SG_DEBUG( "alphas:%f dalphas:%f\n", alphas[i], dalphas[i]);
00147 
00148         primalcool = (maxpviol < primaleps*stopfac);
00149         dualcool = (maxdviol < dualeps*stopfac) || (!free_alpha);
00150 
00151         // done?
00152         if (primalcool && dualcool)
00153         {
00154             if (!free_alpha)
00155                 SG_INFO( " no free alpha, stopping! #iter=%d\n", niter);
00156             else
00157                 SG_INFO( " done! #iter=%d\n", niter);
00158             break;
00159         }
00160 
00161 
00162         ASSERT(maxpidx>=0 && maxpidx<n);
00163         // hessian updates
00164         hstep=-hessres[maxpidx]/compute_H(maxpidx,maxpidx);
00165         //hstep[0]=-hessres[maxpidx]/(compute_H(maxpidx,maxpidx)+hessreg);
00166         //hstep[1]=-hessres[maxpidx+n]/(compute_H(maxpidx,maxpidx)+hessreg);
00167 
00168         hessest-=F[maxpidx]*hstep;
00169         //hessest[0]-=F[maxpidx]*hstep[0];
00170         //hessest[1]-=F[maxpidx+n]*hstep[1];
00171 
00172         // do primal updates ..
00173         DREAL tmpalpha = alphas[maxpidx] - dalphas[maxpidx]/compute_H(maxpidx,maxpidx);
00174 
00175         if (tmpalpha > d-alpha_eps) 
00176             tmpalpha = d;
00177 
00178         if (tmpalpha < 0+alpha_eps)
00179             tmpalpha = 0;
00180 
00181         // update alphas & dalphas & detas ...
00182         DREAL alphachange = tmpalpha - alphas[maxpidx];
00183         alphas[maxpidx] = tmpalpha;
00184 
00185         KERNELCACHE_ELEM* h=lock_kernel_row(maxpidx);
00186         for (int i=0; i<n; i++)
00187         {
00188             hessres[i]+=h[i]*hstep;
00189             //hessres[i]+=h[i]*hstep[0];
00190             //hessres[i+n]+=h[i]*hstep[1];
00191             dalphas[i] +=h[i]*alphachange;
00192         }
00193         unlock_kernel_row(maxpidx);
00194 
00195         detas+=F[maxpidx]*alphachange;
00196         //detas[0]+=F[maxpidx]*alphachange;
00197         //detas[1]+=F[maxpidx+n]*alphachange;
00198 
00199         // if at primal minimum, do eta update ...            
00200         if (primalcool)
00201         {
00202             //DREAL etachange[2] = { detas[0]/hessest[0] , detas[1]/hessest[1] };
00203             DREAL etachange = detas/hessest;
00204 
00205             etas+=etachange;        
00206             //etas[0]+=etachange[0];        
00207             //etas[1]+=etachange[1];        
00208 
00209             // update dalphas
00210             for (int i=0; i<n; i++)
00211                 dalphas[i]+= F[i] * etachange;
00212             //dalphas[i]+= F[i] * etachange[0] + F[i+n] * etachange[1];
00213         }
00214     }
00215 
00216     if (niter >= maxiter)
00217         SG_WARNING( "increase maxiter ... \n");
00218 
00219 
00220     int nsv=0;
00221     for (int i=0; i<n; i++)
00222     {
00223         if (alphas[i]>0)
00224             nsv++;
00225     }
00226 
00227 
00228     create_new_model(nsv);
00229     //set_bias(etas[0]/etas[1]);
00230     set_bias(etas);
00231 
00232     int j=0;
00233     for (int i=0; i<n; i++)
00234     {
00235         if (alphas[i]>0)
00236         {
00237             //set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]);
00238             set_alpha(j, alphas[i]*labels->get_label(i));
00239             set_support_vector(j, i);
00240             j++;
00241         }
00242     }
00243     compute_objective();
00244     SG_INFO( "obj = %.16f, rho = %.16f\n",get_objective(),get_bias());
00245     SG_INFO( "Number of SV: %ld\n", get_num_support_vectors());
00246 
00247     delete[] alphas;
00248     delete[] dalphas;
00249     delete[] hessres;
00250     delete[] F;
00251     delete kernel_cache;
00252 
00253     return true;
00254 }

SHOGUN Machine Learning Toolbox - Documentation