00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "lib/Signal.h"
00014 #include "base/Parallel.h"
00015
00016 #include "classifier/svm/SVM.h"
00017
00018 #include <string.h>
00019
00020 #ifndef WIN32
00021 #include <pthread.h>
00022 #endif
00023
/** Parameter bundle handed to each classification worker thread
 * (see CSVM::classify_example_helper). */
struct S_THREAD_PARAM
{
	/** SVM whose outputs are computed */
	CSVM* svm;
	/** labels object receiving the computed outputs */
	CLabels* result;
	/** index of the first example this thread handles (inclusive) */
	int32_t start;
	/** index one past the last example this thread handles (exclusive) */
	int32_t end;
	/** if true, this thread reports progress to the user */
	bool verbose;
};
00032
/** Construct an SVM with default settings; when num_sv>0 a model with
 * that many support vectors is pre-allocated (see set_defaults). */
CSVM::CSVM(int32_t num_sv)
: CKernelMachine()
{
	set_defaults(num_sv);
}
00038
/** Construct an SVM with default settings, then set the regularization
 * constant C (same value for both classes), labels and kernel. */
CSVM::CSVM(float64_t C, CKernel* k, CLabels* lab)
: CKernelMachine()
{
	set_defaults();
	set_C(C,C);
	set_labels(lab);
	set_kernel(k);
}
00047
/** Destructor: releases the model's alpha and support-vector index
 * arrays allocated by create_new_model. */
CSVM::~CSVM()
{
	delete[] svm_model.alpha;
	delete[] svm_model.svs;

	SG_DEBUG("SVM object destroyed\n");
}
00055
/** Reset every tunable of the SVM to its default value and clear the
 * model; when num_sv>0 additionally pre-allocate a model of that size.
 *
 * @param num_sv number of support vectors to pre-allocate (<=0: none)
 */
void CSVM::set_defaults(int32_t num_sv)
{
	// empty model: no bias, no alphas, no support vectors
	svm_model.b=0.0;
	svm_model.alpha=NULL;
	svm_model.svs=NULL;
	svm_model.num_svs=0;
	svm_loaded=false;

	// optimization tolerances
	weight_epsilon=1e-5;
	epsilon=1e-5;
	tube_epsilon=1e-2;

	// regularization defaults
	nu=0.5;
	C1=1;
	C2=1;
	C_mkl=0;
	mkl_norm=1;

	objective=0;

	// solver behavior flags
	qpsize=41;
	use_bias=true;
	use_shrinking=true;
	use_mkl=false;
	use_batch_computation=true;
	use_linadd=true;

	if (num_sv>0)
		create_new_model(num_sv);
}
00086
00087 bool CSVM::load(FILE* modelfl)
00088 {
00089 bool result=true;
00090 char char_buffer[1024];
00091 int32_t int_buffer;
00092 float64_t double_buffer;
00093 int32_t line_number=1;
00094
00095 if (fscanf(modelfl,"%4s\n", char_buffer)==EOF)
00096 {
00097 result=false;
00098 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00099 }
00100 else
00101 {
00102 char_buffer[4]='\0';
00103 if (strcmp("%SVM", char_buffer)!=0)
00104 {
00105 result=false;
00106 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00107 }
00108 line_number++;
00109 }
00110
00111 int_buffer=0;
00112 if (fscanf(modelfl," numsv=%d; \n", &int_buffer) != 1)
00113 {
00114 result=false;
00115 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00116 }
00117
00118 if (!feof(modelfl))
00119 line_number++;
00120
00121 SG_INFO( "loading %ld support vectors\n",int_buffer);
00122 create_new_model(int_buffer);
00123
00124 if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00125 {
00126 result=false;
00127 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00128 }
00129
00130 if (!feof(modelfl))
00131 line_number++;
00132
00133 double_buffer=0;
00134
00135 if (fscanf(modelfl," b=%lf; \n", &double_buffer) != 1)
00136 {
00137 result=false;
00138 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00139 }
00140
00141 if (!feof(modelfl))
00142 line_number++;
00143
00144 set_bias(double_buffer);
00145
00146 if (fscanf(modelfl,"%8s\n", char_buffer) == EOF)
00147 {
00148 result=false;
00149 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00150 }
00151 else
00152 {
00153 char_buffer[9]='\0';
00154 if (strcmp("alphas=[", char_buffer)!=0)
00155 {
00156 result=false;
00157 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00158 }
00159 line_number++;
00160 }
00161
00162 for (int32_t i=0; i<get_num_support_vectors(); i++)
00163 {
00164 double_buffer=0;
00165 int_buffer=0;
00166
00167 if (fscanf(modelfl," \[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00168 {
00169 result=false;
00170 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00171 }
00172
00173 if (!feof(modelfl))
00174 line_number++;
00175
00176 set_support_vector(i, int_buffer);
00177 set_alpha(i, double_buffer);
00178 }
00179
00180 if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00181 {
00182 result=false;
00183 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00184 }
00185 else
00186 {
00187 char_buffer[3]='\0';
00188 if (strcmp("];", char_buffer)!=0)
00189 {
00190 result=false;
00191 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00192 }
00193 line_number++;
00194 }
00195
00196 svm_loaded=result;
00197 return result;
00198 }
00199
00200 bool CSVM::save(FILE* modelfl)
00201 {
00202 if (!kernel)
00203 SG_ERROR("Kernel not defined!\n");
00204
00205 SG_INFO( "Writing model file...");
00206 fprintf(modelfl,"%%SVM\n");
00207 fprintf(modelfl,"numsv=%d;\n", get_num_support_vectors());
00208 fprintf(modelfl,"kernel='%s';\n", kernel->get_name());
00209 fprintf(modelfl,"b=%+10.16e;\n",get_bias());
00210
00211 fprintf(modelfl, "alphas=\[\n");
00212
00213 for(int32_t i=0; i<get_num_support_vectors(); i++)
00214 fprintf(modelfl,"\t[%+10.16e,%d];\n",
00215 CSVM::get_alpha(i), get_support_vector(i));
00216
00217 fprintf(modelfl, "];\n");
00218
00219 SG_DONE();
00220 return true ;
00221 }
00222
00223 bool CSVM::init_kernel_optimization()
00224 {
00225 int32_t num_sv=get_num_support_vectors();
00226
00227 if (kernel && kernel->has_property(KP_LINADD) && num_sv>0)
00228 {
00229 int32_t * sv_idx = new int32_t[num_sv] ;
00230 float64_t* sv_weight = new float64_t[num_sv] ;
00231
00232 for(int32_t i=0; i<num_sv; i++)
00233 {
00234 sv_idx[i] = get_support_vector(i) ;
00235 sv_weight[i] = get_alpha(i) ;
00236 }
00237
00238 bool ret = kernel->init_optimization(num_sv, sv_idx, sv_weight) ;
00239
00240 delete[] sv_idx ;
00241 delete[] sv_weight ;
00242
00243 if (!ret)
00244 SG_ERROR( "initialization of kernel optimization failed\n");
00245
00246 return ret;
00247 }
00248 else
00249 SG_ERROR( "initialization of kernel optimization failed\n");
00250
00251 return false;
00252 }
00253
/** Thread entry point: classify the examples [start,end) described by
 * the S_THREAD_PARAM passed in p and store the outputs in its result
 * labels.  On non-Windows builds the loop additionally polls CSignal
 * so a Ctrl-C can stop the computation early.
 *
 * @param p pointer to an S_THREAD_PARAM
 * @return always NULL (pthread-compatible signature)
 */
void* CSVM::classify_example_helper(void* p)
{
	S_THREAD_PARAM* params= (S_THREAD_PARAM*) p;
	CLabels* result=params->result;
	CSVM* svm=params->svm;

#ifdef WIN32
	for (int32_t vec=params->start; vec<params->end; vec++)
#else
	CSignal::clear_cancel();
	for (int32_t vec=params->start; vec<params->end &&
			!CSignal::cancel_computations(); vec++)
#endif
	{
		if (params->verbose)
		{
			// report progress roughly once per percent of this
			// thread's share (+1 avoids division by zero)
			int32_t num_vectors=params->end - params->start;
			int32_t v=vec-params->start;
			if ( (v% (num_vectors/100+1))== 0)
				SG_SPROGRESS(v, 0.0, num_vectors-1);
		}

		result->set_label(vec, svm->classify_example(vec));
	}

	return NULL;
}
00281
00282 CLabels* CSVM::classify(CLabels* lab)
00283 {
00284 if (!kernel)
00285 {
00286 SG_ERROR( "SVM can not proceed without kernel!\n");
00287 return false ;
00288 }
00289
00290 if ( kernel && kernel->get_num_vec_rhs()>0 )
00291 {
00292 int32_t num_vectors=kernel->get_num_vec_rhs();
00293
00294 if (!lab)
00295 lab=new CLabels(num_vectors);
00296 SG_DEBUG( "computing output on %d test examples\n", num_vectors);
00297
00298 if (this->io.get_show_progress())
00299 this->io.enable_progress();
00300 else
00301 this->io.disable_progress();
00302
00303 if (kernel->has_property(KP_BATCHEVALUATION) &&
00304 get_batch_computation_enabled())
00305 {
00306 ASSERT(get_num_support_vectors()>0);
00307 int32_t* sv_idx=new int32_t[get_num_support_vectors()];
00308 float64_t* sv_weight=new float64_t[get_num_support_vectors()];
00309 int32_t* idx=new int32_t[num_vectors];
00310 float64_t* output=new float64_t[num_vectors];
00311 memset(output, 0, sizeof(float64_t)*num_vectors);
00312
00313
00314 for (int32_t i=0; i<num_vectors; i++)
00315 idx[i]=i;
00316
00317 for (int32_t i=0; i<get_num_support_vectors(); i++)
00318 {
00319 sv_idx[i] = get_support_vector(i) ;
00320 sv_weight[i] = get_alpha(i) ;
00321 }
00322
00323 kernel->compute_batch(num_vectors, idx,
00324 output, get_num_support_vectors(), sv_idx, sv_weight);
00325
00326 for (int32_t i=0; i<num_vectors; i++)
00327 lab->set_label(i, get_bias()+output[i]);
00328
00329 delete[] sv_idx ;
00330 delete[] sv_weight ;
00331 delete[] idx;
00332 delete[] output;
00333 }
00334 else
00335 {
00336 int32_t num_threads=parallel.get_num_threads();
00337 ASSERT(num_threads>0);
00338
00339 if (num_threads < 2)
00340 {
00341 S_THREAD_PARAM params;
00342 params.svm=this;
00343 params.result=lab;
00344 params.start=0;
00345 params.end=num_vectors;
00346 params.verbose=true;
00347 classify_example_helper((void*) ¶ms);
00348 }
00349 #ifndef WIN32
00350 else
00351 {
00352 pthread_t threads[num_threads-1];
00353 S_THREAD_PARAM params[num_threads];
00354 int32_t step= num_vectors/num_threads;
00355
00356 int32_t t;
00357
00358 for (t=0; t<num_threads-1; t++)
00359 {
00360 params[t].svm = this;
00361 params[t].result = lab;
00362 params[t].start = t*step;
00363 params[t].end = (t+1)*step;
00364 params[t].verbose = false;
00365 pthread_create(&threads[t], NULL,
00366 CSVM::classify_example_helper, (void*)¶ms[t]);
00367 }
00368
00369 params[t].svm = this;
00370 params[t].result = lab;
00371 params[t].start = t*step;
00372 params[t].end = num_vectors;
00373 params[t].verbose = true;
00374 classify_example_helper((void*) ¶ms[t]);
00375
00376 for (t=0; t<num_threads-1; t++)
00377 pthread_join(threads[t], NULL);
00378 }
00379 #endif
00380 }
00381
00382 #ifndef WIN32
00383 if ( CSignal::cancel_computations() )
00384 SG_INFO( "prematurely stopped. \n");
00385 else
00386 #endif
00387 SG_DONE();
00388 }
00389 else
00390 return NULL;
00391
00392 return lab;
00393 }
00394
00395 float64_t CSVM::classify_example(int32_t num)
00396 {
00397 ASSERT(kernel);
00398
00399 if (kernel->has_property(KP_LINADD) && (kernel->get_is_initialized()))
00400 {
00401 float64_t dist = kernel->compute_optimized(num);
00402 return (dist+get_bias());
00403 }
00404 else
00405 {
00406 float64_t dist=0;
00407 for(int32_t i=0; i<get_num_support_vectors(); i++)
00408 dist+=kernel->kernel(get_support_vector(i), num)*get_alpha(i);
00409
00410 return (dist+get_bias());
00411 }
00412 }
00413
00414
00415 float64_t CSVM::compute_objective()
00416 {
00417 int32_t n=get_num_support_vectors();
00418
00419 if (labels && kernel)
00420 {
00421 objective=0;
00422 for (int32_t i=0; i<n; i++)
00423 {
00424 objective-=get_alpha(i)*labels->get_label(i);
00425 for (int32_t j=0; j<n; j++)
00426 objective+=0.5*get_alpha(i)*get_alpha(j)*kernel->kernel(i,j);
00427 }
00428 }
00429 else
00430 SG_ERROR( "cannot compute objective, labels or kernel not set\n");
00431
00432 return objective;
00433 }