MPDSVM.cpp
Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "classifier/svm/MPDSVM.h"
00012 #include "lib/io.h"
00013 #include "lib/common.h"
00014 #include "lib/Mathematics.h"
00015
00016 CMPDSVM::CMPDSVM()
00017 : CSVM()
00018 {
00019 }
00020
00021 CMPDSVM::CMPDSVM(float64_t C, CKernel* k, CLabels* lab)
00022 : CSVM(C, k, lab)
00023 {
00024 }
00025
00026 CMPDSVM::~CMPDSVM()
00027 {
00028 }
00029
00030 bool CMPDSVM::train()
00031 {
00032 ASSERT(labels);
00033 ASSERT(kernel && kernel->has_features());
00034
00035 const float64_t alpha_eps=1e-12;
00036 const float64_t eps=get_epsilon();
00037 const int64_t maxiter = 1L<<30;
00038
00039
00040 const int32_t n=labels->get_num_labels();
00041 ASSERT(n>0);
00042
00043 const float64_t d = get_C1();
00044 const float64_t primaleps=eps;
00045 const float64_t dualeps=eps*n;
00046 int64_t niter=0;
00047
00048 kernel_cache = new CCache<KERNELCACHE_ELEM>(kernel->get_cache_size(), n, n);
00049 float64_t* alphas=new float64_t[n];
00050 float64_t* dalphas=new float64_t[n];
00051
00052 float64_t* hessres=new float64_t[n];
00053
00054 float64_t* F=new float64_t[n];
00055
00056
00057
00058
00059
00060 float64_t etas=0;
00061 float64_t detas=0;
00062 float64_t hessest=0;
00063 float64_t hstep;
00064
00065 const float64_t stopfac = 1;
00066
00067 bool primalcool;
00068 bool dualcool;
00069
00070
00071
00072
00073 for (int32_t i=0; i<n; i++)
00074 {
00075 alphas[i]=0;
00076 F[i]=labels->get_label(i);
00077
00078 hessres[i]=labels->get_label(i);
00079
00080
00081 dalphas[i]=-1;
00082 }
00083
00084
00085 while (niter++ < maxiter)
00086 {
00087 int32_t maxpidx=-1;
00088 float64_t maxpviol = -1;
00089
00090 float64_t maxdviol = CMath::abs(detas);
00091 bool free_alpha=false;
00092
00093
00094
00095
00096
00097 for (int32_t i=0; i<n; i++)
00098 {
00099 float64_t v=CMath::abs(dalphas[i]);
00100
00101 if (alphas[i] > 0 && alphas[i] < d)
00102 free_alpha=true;
00103
00104 if ( (dalphas[i]==0) ||
00105 (alphas[i]==0 && dalphas[i] >0) ||
00106 (alphas[i]==d && dalphas[i] <0)
00107 )
00108 v=0;
00109
00110 if (v > maxpviol)
00111 {
00112 maxpviol=v;
00113 maxpidx=i;
00114 }
00115 else if (v == maxpviol)
00116 {
00117 if (kernel_cache->is_cached(i))
00118 maxpidx=i;
00119 }
00120 }
00121
00122 if (maxpidx<0 || maxdviol<0)
00123 SG_ERROR( "no violation no convergence, should not happen!\n");
00124
00125
00126
00127
00128
00129
00130
00131 if (niter%10000 == 0)
00132 {
00133 float64_t obj=0;
00134
00135 for (int32_t i=0; i<n; i++)
00136 {
00137 obj-=alphas[i];
00138 for (int32_t j=0; j<n; j++)
00139 obj+=0.5*labels->get_label(i)*labels->get_label(j)*alphas[i]*alphas[j]*kernel->kernel(i,j);
00140 }
00141
00142 SG_DEBUG( "obj:%f pviol:%f dviol:%f maxpidx:%d iter:%d\n", obj, maxpviol, maxdviol, maxpidx, niter);
00143 }
00144
00145
00146
00147
00148 primalcool = (maxpviol < primaleps*stopfac);
00149 dualcool = (maxdviol < dualeps*stopfac) || (!free_alpha);
00150
00151
00152 if (primalcool && dualcool)
00153 {
00154 if (!free_alpha)
00155 SG_INFO( " no free alpha, stopping! #iter=%d\n", niter);
00156 else
00157 SG_INFO( " done! #iter=%d\n", niter);
00158 break;
00159 }
00160
00161
00162 ASSERT(maxpidx>=0 && maxpidx<n);
00163
00164 hstep=-hessres[maxpidx]/compute_H(maxpidx,maxpidx);
00165
00166
00167
00168 hessest-=F[maxpidx]*hstep;
00169
00170
00171
00172
00173 float64_t tmpalpha = alphas[maxpidx] - dalphas[maxpidx]/compute_H(maxpidx,maxpidx);
00174
00175 if (tmpalpha > d-alpha_eps)
00176 tmpalpha = d;
00177
00178 if (tmpalpha < 0+alpha_eps)
00179 tmpalpha = 0;
00180
00181
00182 float64_t alphachange = tmpalpha - alphas[maxpidx];
00183 alphas[maxpidx] = tmpalpha;
00184
00185 KERNELCACHE_ELEM* h=lock_kernel_row(maxpidx);
00186 for (int32_t i=0; i<n; i++)
00187 {
00188 hessres[i]+=h[i]*hstep;
00189
00190
00191 dalphas[i] +=h[i]*alphachange;
00192 }
00193 unlock_kernel_row(maxpidx);
00194
00195 detas+=F[maxpidx]*alphachange;
00196
00197
00198
00199
00200 if (primalcool)
00201 {
00202
00203 float64_t etachange = detas/hessest;
00204
00205 etas+=etachange;
00206
00207
00208
00209
00210 for (int32_t i=0; i<n; i++)
00211 dalphas[i]+= F[i] * etachange;
00212
00213 }
00214 }
00215
00216 if (niter >= maxiter)
00217 SG_WARNING( "increase maxiter ... \n");
00218
00219
00220 int32_t nsv=0;
00221 for (int32_t i=0; i<n; i++)
00222 {
00223 if (alphas[i]>0)
00224 nsv++;
00225 }
00226
00227
00228 create_new_model(nsv);
00229
00230 set_bias(etas);
00231
00232 int32_t j=0;
00233 for (int32_t i=0; i<n; i++)
00234 {
00235 if (alphas[i]>0)
00236 {
00237
00238 set_alpha(j, alphas[i]*labels->get_label(i));
00239 set_support_vector(j, i);
00240 j++;
00241 }
00242 }
00243 compute_objective();
00244 SG_INFO( "obj = %.16f, rho = %.16f\n",get_objective(),get_bias());
00245 SG_INFO( "Number of SV: %ld\n", get_num_support_vectors());
00246
00247 delete[] alphas;
00248 delete[] dalphas;
00249 delete[] hessres;
00250 delete[] F;
00251 delete kernel_cache;
00252
00253 return true;
00254 }